mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
test: create check_array_equality fixture
This commit is contained in:
@@ -378,6 +378,7 @@ def test_fuse_float_operations(
|
||||
warning_message,
|
||||
capfd,
|
||||
remove_color_codes,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test function for fuse_float_operations"""
|
||||
|
||||
@@ -405,7 +406,7 @@ def test_fuse_float_operations(
|
||||
input_ = numpy.full(param_input_value.shape, input_, dtype=numpy.int32)
|
||||
inputs += (input_,)
|
||||
|
||||
assert numpy.array_equal(function_to_trace(*inputs), op_graph(*inputs))
|
||||
check_array_equality(function_to_trace(*inputs), op_graph(*inputs))
|
||||
|
||||
|
||||
def subtest_tensor_no_fuse(fun, tensor_shape):
|
||||
|
||||
@@ -220,10 +220,11 @@ def test_evaluate(
|
||||
node: ir.IntermediateNode,
|
||||
input_data,
|
||||
expected_result: int,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test evaluate methods on IntermediateNodes"""
|
||||
if isinstance(expected_result, numpy.ndarray):
|
||||
assert numpy.array_equal(node.evaluate(input_data), expected_result)
|
||||
check_array_equality(node.evaluate(input_data), expected_result)
|
||||
else:
|
||||
assert node.evaluate(input_data) == expected_result
|
||||
|
||||
|
||||
@@ -335,3 +335,30 @@ def check_is_good_execution():
|
||||
"""Fixture to seed torch"""
|
||||
|
||||
return check_is_good_execution_impl
|
||||
|
||||
|
||||
def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True):
|
||||
"""Assert that `actual` is equal to `expected`."""
|
||||
|
||||
assert numpy.array_equal(actual, expected), (
|
||||
""
|
||||
if not verbose
|
||||
else f"""
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{actual}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_array_equality():
|
||||
"""Fixture to check array equality"""
|
||||
|
||||
return check_array_equality_impl
|
||||
|
||||
@@ -1096,6 +1096,7 @@ def test_compile_and_run_tensor_correctness(
|
||||
use_check_good_exec,
|
||||
default_compilation_configuration,
|
||||
check_is_good_execution,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
@@ -1113,7 +1114,7 @@ def test_compile_and_run_tensor_correctness(
|
||||
if use_check_good_exec:
|
||||
check_is_good_execution(circuit, function, numpy_test_input)
|
||||
else:
|
||||
assert numpy.array_equal(
|
||||
check_array_equality(
|
||||
circuit.run(*numpy_test_input),
|
||||
numpy.array(function(*numpy_test_input), dtype=numpy.uint8),
|
||||
)
|
||||
@@ -1336,7 +1337,7 @@ def test_compile_and_run_constant_dot_correctness(
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_matmul_correctness(
|
||||
lhs_shape, rhs_shape, input_range, default_compilation_configuration
|
||||
lhs_shape, rhs_shape, input_range, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
@@ -1371,8 +1372,8 @@ def test_compile_and_run_matmul_correctness(
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=lhs_shape, dtype=numpy.uint8),)
|
||||
assert numpy.array_equal(operator_circuit.run(*args), using_operator(*args))
|
||||
assert numpy.array_equal(function_circuit.run(*args), using_function(*args))
|
||||
check_array_equality(operator_circuit.run(*args), using_operator(*args))
|
||||
check_array_equality(function_circuit.run(*args), using_function(*args))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -709,6 +709,7 @@ def test_constant_indexing_run_correctness(
|
||||
test_input,
|
||||
expected_output,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
@@ -725,19 +726,7 @@ def test_constant_indexing_run_correctness(
|
||||
output = circuit.run(*numpy_test_input)
|
||||
expected = np.array(expected_output, dtype=np.uint8)
|
||||
|
||||
assert np.array_equal(
|
||||
output, expected
|
||||
), f"""
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{output}
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{expected}
|
||||
|
||||
"""
|
||||
check_array_equality(output, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -22,13 +22,21 @@ def complicated_topology(x, y):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_shape", [(), (3, 1, 2)])
|
||||
def test_np_fhe_compiler_op_graph(input_shape, default_compilation_configuration):
|
||||
def test_np_fhe_compiler_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""Test NPFHECompiler in two subtests."""
|
||||
subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration)
|
||||
subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration)
|
||||
subtest_np_fhe_compiler_1_input_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
)
|
||||
subtest_np_fhe_compiler_2_inputs_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration):
|
||||
def subtest_np_fhe_compiler_1_input_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""test for NPFHECompiler on one input function"""
|
||||
|
||||
def function_to_compile(x):
|
||||
@@ -48,7 +56,7 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co
|
||||
|
||||
for i in numpy.arange(5):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
# For coverage, check that we flush the inputset when we query the OPGraph
|
||||
current_op_graph = compiler.op_graph
|
||||
@@ -60,7 +68,7 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co
|
||||
# Continue a bit more
|
||||
for i in numpy.arange(5, 10):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
if input_shape == ():
|
||||
assert (
|
||||
@@ -106,7 +114,9 @@ def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_co
|
||||
), got
|
||||
|
||||
|
||||
def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration):
|
||||
def subtest_np_fhe_compiler_2_inputs_op_graph(
|
||||
input_shape, default_compilation_configuration, check_array_equality
|
||||
):
|
||||
"""test for NPFHECompiler on two inputs function"""
|
||||
|
||||
compiler = NPFHECompiler(
|
||||
@@ -124,7 +134,7 @@ def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_c
|
||||
for i, j in zip(numpy.arange(5), numpy.arange(5, 10)):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
j = numpy.ones(input_shape, dtype=numpy.int64) * j
|
||||
assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
|
||||
check_array_equality(compiler(i, j), complicated_topology(i, j))
|
||||
|
||||
# For coverage, check that we flush the inputset when we query the OPGraph
|
||||
current_op_graph = compiler.op_graph
|
||||
@@ -137,7 +147,7 @@ def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_c
|
||||
for i, j in zip(numpy.arange(5, 10), numpy.arange(5)):
|
||||
i = numpy.ones(input_shape, dtype=numpy.int64) * i
|
||||
j = numpy.ones(input_shape, dtype=numpy.int64) * j
|
||||
assert numpy.array_equal(compiler(i, j), complicated_topology(i, j))
|
||||
check_array_equality(compiler(i, j), complicated_topology(i, j))
|
||||
|
||||
if input_shape == ():
|
||||
assert (
|
||||
@@ -200,6 +210,7 @@ def test_np_fhe_compiler_auto_flush(
|
||||
inputset_len,
|
||||
expected_remaining_inputset_len,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test the auto flush of NPFHECompiler once the inputset is 128 elements."""
|
||||
|
||||
@@ -213,7 +224,7 @@ def test_np_fhe_compiler_auto_flush(
|
||||
)
|
||||
|
||||
for i in numpy.arange(inputset_len):
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
# Check the inputset was properly flushed
|
||||
assert (
|
||||
@@ -222,7 +233,7 @@ def test_np_fhe_compiler_auto_flush(
|
||||
)
|
||||
|
||||
|
||||
def test_np_fhe_compiler_full_compilation(default_compilation_configuration):
|
||||
def test_np_fhe_compiler_full_compilation(default_compilation_configuration, check_array_equality):
|
||||
"""Test the case where we generate an FHE circuit."""
|
||||
|
||||
def function_to_compile(x):
|
||||
@@ -244,7 +255,7 @@ def test_np_fhe_compiler_full_compilation(default_compilation_configuration):
|
||||
)
|
||||
|
||||
for i in numpy.arange(64):
|
||||
assert numpy.array_equal(compiler(i), function_to_compile(i))
|
||||
check_array_equality(compiler(i), function_to_compile(i))
|
||||
|
||||
fhe_circuit = compiler.get_compiled_fhe_circuit()
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ def test_tracing_astype(
|
||||
assert expected_output == evaluated_output
|
||||
|
||||
|
||||
def test_tracing_astype_single_element_array_corner_case():
|
||||
def test_tracing_astype_single_element_array_corner_case(check_array_equality):
|
||||
"""Test corner case where an array could be transformed to its scalar element"""
|
||||
a = numpy.array([1], dtype=numpy.float64)
|
||||
|
||||
@@ -336,7 +336,7 @@ def test_tracing_astype_single_element_array_corner_case():
|
||||
)
|
||||
|
||||
eval_result = op_graph(a)
|
||||
assert numpy.array_equal(numpy.array([1], dtype=numpy.int32), eval_result)
|
||||
check_array_equality(eval_result, numpy.array([1], dtype=numpy.int32))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -492,6 +492,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
def subtest_tracing_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
@@ -502,11 +503,7 @@ def subtest_tracing_calls(
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
if not numpy.array_equal(expected_output, evaluated_output):
|
||||
print("Wrong result")
|
||||
print(f"Expected: {expected_output}")
|
||||
print(f"Got : {evaluated_output}")
|
||||
raise AssertionError
|
||||
check_array_equality(evaluated_output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -582,9 +579,12 @@ def subtest_tracing_calls(
|
||||
def test_tracing_numpy_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -964,9 +964,12 @@ def test_tracing_numpy_calls(
|
||||
def test_tracing_ndarray_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -142,6 +142,7 @@ def test_nptracer_get_tracing_func_for_np_functions(np_function):
|
||||
def subtest_tracing_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
for input_value, input_, expected_output in input_value_input_and_expected_output_tuples:
|
||||
@@ -152,7 +153,7 @@ def subtest_tracing_calls(
|
||||
node_results = op_graph.evaluate({0: input_})
|
||||
evaluated_output = node_results[output_node]
|
||||
assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
|
||||
assert numpy.array_equal(expected_output, evaluated_output)
|
||||
check_array_equality(evaluated_output, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -228,9 +229,12 @@ def subtest_tracing_calls(
|
||||
def test_tracing_numpy_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form numpy.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -297,6 +301,9 @@ def test_tracing_numpy_calls(
|
||||
def test_tracing_ndarray_calls(
|
||||
function_to_trace,
|
||||
input_value_input_and_expected_output_tuples,
|
||||
check_array_equality,
|
||||
):
|
||||
"""Test memory function managed by GenericFunction node of the form ndarray.something"""
|
||||
subtest_tracing_calls(function_to_trace, input_value_input_and_expected_output_tuples)
|
||||
subtest_tracing_calls(
|
||||
function_to_trace, input_value_input_and_expected_output_tuples, check_array_equality
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ N_BITS_ATOL_TUPLE_LIST = [
|
||||
)
|
||||
@pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)])
|
||||
@pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))])
|
||||
def test_quant_dequant_update(values, n_bits, atol, is_signed):
|
||||
def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality):
|
||||
"""Test the quant and dequant function."""
|
||||
|
||||
quant_array = QuantizedArray(n_bits, values, is_signed)
|
||||
@@ -51,4 +51,4 @@ def test_quant_dequant_update(values, n_bits, atol, is_signed):
|
||||
assert not numpy.array_equal(new_values, new_values_updated)
|
||||
|
||||
# Check that the __call__ returns also the qvalues.
|
||||
assert numpy.array_equal(quant_array(), new_qvalues)
|
||||
check_array_equality(quant_array(), new_qvalues)
|
||||
|
||||
Reference in New Issue
Block a user