diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 0153deaa4..062e77890 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -378,6 +378,7 @@ def test_fuse_float_operations( warning_message, capfd, remove_color_codes, + check_array_equality, ): """Test function for fuse_float_operations""" @@ -405,7 +406,7 @@ def test_fuse_float_operations( input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32) inputs += (input_,) - assert numpy.array_equal(function_to_trace(*inputs), op_graph(*inputs)) + check_array_equality(function_to_trace(*inputs), op_graph(*inputs)) def subtest_tensor_no_fuse(fun, tensor_shape): diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index e97f19907..8a8a9e755 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -220,10 +220,11 @@ def test_evaluate( node: ir.IntermediateNode, input_data, expected_result: int, + check_array_equality, ): """Test evaluate methods on IntermediateNodes""" if isinstance(expected_result, numpy.ndarray): - assert numpy.array_equal(node.evaluate(input_data), expected_result) + check_array_equality(node.evaluate(input_data), expected_result) else: assert node.evaluate(input_data) == expected_result diff --git a/tests/conftest.py b/tests/conftest.py index ad4e6c360..411f9db07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -335,3 +335,30 @@ def check_is_good_execution(): """Fixture to seed torch""" return check_is_good_execution_impl + + +def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True): + """Assert that `actual` is equal to `expected`.""" + + assert numpy.array_equal(actual, expected), ( + "" + if not verbose + else f""" + +Expected Output +=============== +{expected} + +Actual Output +============= +{actual} + + """ + ) + + +@pytest.fixture +def check_array_equality(): + """Fixture to check array equality""" + + return check_array_equality_impl diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index b88eb47d5..69ee69f84 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1096,6 +1096,7 @@ def test_compile_and_run_tensor_correctness( use_check_good_exec, default_compilation_configuration, check_is_good_execution, + check_array_equality, ): """Test correctness of results when running a compiled function with tensor operators""" circuit = compile_numpy_function( @@ -1113,7 +1114,7 @@ def test_compile_and_run_tensor_correctness( if use_check_good_exec: check_is_good_execution(circuit, function, numpy_test_input) else: - assert numpy.array_equal( + check_array_equality( circuit.run(*numpy_test_input), numpy.array(function(*numpy_test_input), dtype=numpy.uint8), ) @@ -1336,7 +1337,7 @@ def test_compile_and_run_constant_dot_correctness( ], ) def test_compile_and_run_matmul_correctness( - lhs_shape, rhs_shape, input_range, default_compilation_configuration + lhs_shape, rhs_shape, input_range, default_compilation_configuration, check_array_equality ): """Test correctness of results when running a compiled function""" @@ -1371,8 +1372,8 @@ def test_compile_and_run_matmul_correctness( ) args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),) - assert numpy.array_equal(operator_circuit.run(*args), using_operator(*args)) - assert numpy.array_equal(function_circuit.run(*args), using_function(*args)) + check_array_equality(operator_circuit.run(*args), using_operator(*args)) + check_array_equality(function_circuit.run(*args), using_function(*args)) @pytest.mark.parametrize( diff --git a/tests/numpy/test_compile_constant_indexing.py b/tests/numpy/test_compile_constant_indexing.py index 99698be2a..606922afa 100644 --- a/tests/numpy/test_compile_constant_indexing.py +++ b/tests/numpy/test_compile_constant_indexing.py @@ -709,6 +709,7 @@ def test_constant_indexing_run_correctness( test_input, expected_output, default_compilation_configuration, + check_array_equality, ): """Test correctness of results when running a compiled function with tensor operators""" circuit = compile_numpy_function( @@ -725,19 +726,7 @@ def test_constant_indexing_run_correctness( output = circuit.run(*numpy_test_input) expected = np.array(expected_output, dtype=np.uint8) - assert np.array_equal( - output, expected - ), f""" - -Actual Output -============= -{output} - -Expected Output -=============== -{expected} - - """ + check_array_equality(output, expected) @pytest.mark.parametrize( diff --git a/tests/numpy/test_compile_user_friendly_api.py b/tests/numpy/test_compile_user_friendly_api.py index 10061313b..bdcc7d9ed 100644 --- a/tests/numpy/test_compile_user_friendly_api.py +++ b/tests/numpy/test_compile_user_friendly_api.py @@ -22,13 +22,21 @@ def complicated_topology(x, y): @pytest.mark.parametrize("input_shape", [(), (3, 1, 2)]) -def test_np_fhe_compiler_op_graph(input_shape, default_compilation_configuration): +def test_np_fhe_compiler_op_graph( + input_shape, default_compilation_configuration, check_array_equality +): """Test NPFHECompiler in two subtests.""" - subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration) - subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration) + subtest_np_fhe_compiler_1_input_op_graph( + input_shape, default_compilation_configuration, check_array_equality + ) + subtest_np_fhe_compiler_2_inputs_op_graph( + input_shape, default_compilation_configuration, check_array_equality + ) -def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration): +def subtest_np_fhe_compiler_1_input_op_graph( + input_shape, default_compilation_configuration, check_array_equality +): """test for NPFHECompiler on one input function""" def function_to_compile(x): @@ -48,7 +56,7 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co for i in numpy.arange(5): i = numpy.ones(input_shape, dtype=numpy.int64) * i - assert numpy.array_equal(compiler(i), function_to_compile(i)) + check_array_equality(compiler(i), function_to_compile(i)) # For coverage, check that we flush the inputset when we query the OPGraph current_op_graph = compiler.op_graph @@ -60,7 +68,7 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co # Continue a bit more for i in numpy.arange(5, 10): i = numpy.ones(input_shape, dtype=numpy.int64) * i - assert numpy.array_equal(compiler(i), function_to_compile(i)) + check_array_equality(compiler(i), function_to_compile(i)) if input_shape == (): assert ( @@ -106,7 +114,9 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co ), got -def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration): +def subtest_np_fhe_compiler_2_inputs_op_graph( + input_shape, default_compilation_configuration, check_array_equality +): """test for NPFHECompiler on two inputs function""" compiler = NPFHECompiler( @@ -124,7 +134,7 @@ def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_c for i, j in zip(numpy.arange(5), numpy.arange(5, 10)): i = numpy.ones(input_shape, dtype=numpy.int64) * i j = numpy.ones(input_shape, dtype=numpy.int64) * j - assert numpy.array_equal(compiler(i, j), complicated_topology(i, j)) + check_array_equality(compiler(i, j), complicated_topology(i, j)) # For coverage, check that we flush the inputset when we query the OPGraph current_op_graph = compiler.op_graph @@ -137,7 +147,7 @@ def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_c for i, j in zip(numpy.arange(5, 10), numpy.arange(5)): i = numpy.ones(input_shape, dtype=numpy.int64) * i j = numpy.ones(input_shape, dtype=numpy.int64) * j - assert numpy.array_equal(compiler(i, j), complicated_topology(i, j)) + check_array_equality(compiler(i, j), complicated_topology(i, j)) if input_shape == (): assert ( @@ -200,6 +210,7 @@ def test_np_fhe_compiler_auto_flush( inputset_len, expected_remaining_inputset_len, default_compilation_configuration, + check_array_equality, ): """Test the auto flush of NPFHECompiler once the inputset is 128 elements.""" @@ -213,7 +224,7 @@ def test_np_fhe_compiler_auto_flush( ) for i in numpy.arange(inputset_len): - assert numpy.array_equal(compiler(i), function_to_compile(i)) + check_array_equality(compiler(i), function_to_compile(i)) # Check the inputset was properly flushed assert ( @@ -222,7 +233,7 @@ def test_np_fhe_compiler_auto_flush( ) -def test_np_fhe_compiler_full_compilation(default_compilation_configuration): +def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality): """Test the case where we generate an FHE circuit.""" def function_to_compile(x): @@ -244,7 +255,7 @@ def test_np_fhe_compiler_full_compilation(default_compilation_configuration): ) for i in numpy.arange(64): - assert numpy.array_equal(compiler(i), function_to_compile(i)) + check_array_equality(compiler(i), function_to_compile(i)) fhe_circuit = compiler.get_compiled_fhe_circuit() diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 32403c6d5..69c8e907e 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -327,7 +327,7 @@ def test_tracing_astype( assert expected_output == evaluated_output -def test_tracing_astype_single_element_array_corner_case(): +def test_tracing_astype_single_element_array_corner_case(check_array_equality): """Test corner case where an array could be transformed to its scalar element""" a = numpy.array([1], dtype=numpy.float64) @@ -336,7 +336,7 @@ def test_tracing_astype_single_element_array_corner_case(): ) eval_result = op_graph(a) - assert numpy.array_equal(numpy.array([1], dtype=numpy.int32), eval_result) + check_array_equality(eval_result, numpy.array([1], dtype=numpy.int32)) @pytest.mark.parametrize( @@ -492,6 +492,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match): def subtest_tracing_calls( function_to_trace, input_value_input_and_expected_output_tuples, + check_array_equality, ): """Test memory function managed by GenericFunction node of the form numpy.something""" for input_value, input_, expected_output in input_value_input_and_expected_output_tuples: @@ -502,11 +503,7 @@ def subtest_tracing_calls( node_results = op_graph.evaluate({0: input_}) evaluated_output = node_results[output_node] assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) - if not numpy.array_equal(expected_output, evaluated_output): - print("Wrong result") - print(f"Expected: {expected_output}") - print(f"Got : {evaluated_output}") - raise AssertionError + check_array_equality(evaluated_output, expected_output) @pytest.mark.parametrize( @@ -582,9 +579,12 @@ def subtest_tracing_calls( def test_tracing_numpy_calls( function_to_trace, input_value_input_and_expected_output_tuples, + check_array_equality, ): """Test memory function managed by GenericFunction node of the form numpy.something""" - subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples) + subtest_tracing_calls( + function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality + ) @pytest.mark.parametrize( @@ -964,9 +964,12 @@ def test_tracing_numpy_calls( def test_tracing_ndarray_calls( function_to_trace, input_value_input_and_expected_output_tuples, + check_array_equality, ): """Test memory function managed by GenericFunction node of the form ndarray.something""" - subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples) + subtest_tracing_calls( + function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality + ) @pytest.mark.parametrize( diff --git a/tests/numpy/test_tracing_calls.py b/tests/numpy/test_tracing_calls.py index 95ea588db..3e78fb8fa 100644 --- a/tests/numpy/test_tracing_calls.py +++ b/tests/numpy/test_tracing_calls.py @@ -142,6 +142,7 @@ def test_nptracer_get_tracing_func_for_np_functions(np_function): def subtest_tracing_calls( function_to_trace, input_value_input_and_expected_output_tuples, + check_array_equality, ): """Test memory function managed by GenericFunction node of the form numpy.something""" for input_value, input_, expected_output in input_value_input_and_expected_output_tuples: @@ -152,7 +153,7 @@ def subtest_tracing_calls( node_results = op_graph.evaluate({0: input_}) evaluated_output = node_results[output_node] assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) - assert numpy.array_equal(expected_output, evaluated_output) + check_array_equality(evaluated_output, expected_output) @pytest.mark.parametrize( @@ -228,9 +229,12 @@ def subtest_tracing_calls( def test_tracing_numpy_calls( function_to_trace, input_value_input_and_expected_output_tuples, + check_array_equality, ): """Test memory function managed by GenericFunction node of the form numpy.something""" - subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples) + subtest_tracing_calls( + function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality + ) @pytest.mark.parametrize( @@ -297,6 +301,9 @@ def test_tracing_numpy_calls( def test_tracing_ndarray_calls( function_to_trace, input_value_input_and_expected_output_tuples, + check_array_equality, ): """Test memory function managed by GenericFunction node of the form ndarray.something""" - subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples) + subtest_tracing_calls( + function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality + ) diff --git a/tests/quantization/test_quantized_array.py b/tests/quantization/test_quantized_array.py index be324a5b8..cf22bc4be 100644 --- a/tests/quantization/test_quantized_array.py +++ b/tests/quantization/test_quantized_array.py @@ -20,7 +20,7 @@ N_BITS_ATOL_TUPLE_LIST = [ ) @pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)]) @pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))]) -def test_quant_dequant_update(values, n_bits, atol, is_signed): +def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality): """Test the quant and dequant function.""" quant_array = QuantizedArray(n_bits, values, is_signed) @@ -51,4 +51,4 @@ def test_quant_dequant_update(values, n_bits, atol, is_signed): assert not numpy.array_equal(new_values, new_values_updated) # Check that the __call__ returns also the qvalues. - assert numpy.array_equal(quant_array(), new_qvalues) + check_array_equality(quant_array(), new_qvalues)