diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 4ef224683..c62c004f7 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -18,6 +18,12 @@ from concrete.numpy.compile import compile_numpy_function, compile_numpy_functio # pylint: disable=too-many-lines +def data_gen(args): + """Helper to create an inputset""" + for prod in itertools.product(*args): + yield prod + + def no_fuse_unhandled(x, y): """No fuse unhandled""" x_intermediate = x + 2.8 @@ -259,10 +265,6 @@ def subtest_compile_and_run_unary_ufunc_correctness( function = get_function(ufunc, upper_function) - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = {arg_name: EncryptedScalar(Integer(64, False)) for arg_name in ["x", "y"]} compiler_engine = compile_numpy_function( @@ -291,10 +293,6 @@ def subtest_compile_and_run_binary_ufunc_correctness( function = get_function(ufunc, upper_function) - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = {arg_name: EncryptedScalar(Integer(64, True)) for arg_name in ["x", "y"]} compiler_engine = compile_numpy_function( @@ -476,7 +474,7 @@ def test_compile_function_multiple_outputs( ): """Test function compile_numpy_function_into_op_graph for a program with multiple outputs""" - def data_gen(args): + def data_gen_local(args): for prod in itertools.product(*args): yield tuple(numpy.array(val) for val in prod) @@ -487,7 +485,7 @@ def test_compile_function_multiple_outputs( op_graph = compile_numpy_function_into_op_graph( function, function_parameters, - data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)), default_compilation_configuration, ) @@ -528,10 +526,6 @@ def test_compile_and_run_correctness( ): """Test correctness of results when running a compiled function""" - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = { arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names } @@ -963,10 +957,6 @@ def test_compile_and_run_lut_correctness( input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits) - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = { arg_name: EncryptedScalar(Integer(input_bit, False)) for input_bit, arg_name in zip(input_bits, list_of_arg_names) @@ -1349,7 +1339,7 @@ def test_compile_function_with_dot( # This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'), # we'll have to take random values, not all values one by one - def data_gen(max_for_ij, repeat): + def data_gen_local(max_for_ij, repeat): iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat) iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat) for prod_i, prod_j in itertools.product(iter_i, iter_j): @@ -1362,7 +1352,7 @@ def test_compile_function_with_dot( op_graph = compile_numpy_function_into_op_graph( function, params, - data_gen(max_for_ij, repeat), + data_gen_local(max_for_ij, repeat), default_compilation_configuration, ) str_of_the_graph = format_operation_graph(op_graph) @@ -1389,10 +1379,6 @@ def test_compile_with_show_mlir( ): """Test show_mlir option""" - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = { arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names } @@ -1412,10 +1398,6 @@ def test_compile_too_high_bitwidth(default_compilation_configuration): def function(x, y): return x + y - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = { "x": EncryptedScalar(Integer(64, False)), "y": EncryptedScalar(Integer(64, False)), @@ -1517,10 +1499,6 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration): def test_wrong_inputs(default_compilation_configuration): """Test compilation with faulty inputs""" - def data_gen(args): - for prod in itertools.product(*args): - yield prod - # x should have been something like EncryptedScalar(UnsignedInteger(3)) x = [1, 2, 3] input_ranges = ((0, 10),) @@ -1576,10 +1554,6 @@ def test_compile_and_run_correctness_with_negative_values( """Test correctness of results when running a compiled function, which has some negative intermediate values.""" - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = { arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names } @@ -1614,10 +1588,6 @@ def test_compile_and_run_correctness_with_negative_results( results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg, instead of returning -3, the execution may return -3 mod 128 = 125.""" - def data_gen(args): - for prod in itertools.product(*args): - yield prod - function_parameters = { arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names }