chore: factorize some data_gen.

This commit is contained in:
Benoit Chevallier-Mames
2021-11-15 10:55:46 +01:00
committed by Benoit Chevallier
parent 50c1ceb6db
commit 5d31aa4d2c

View File

@@ -18,6 +18,12 @@ from concrete.numpy.compile import compile_numpy_function, compile_numpy_functio
# pylint: disable=too-many-lines
def data_gen(args):
"""Helper to create an inputset"""
for prod in itertools.product(*args):
yield prod
def no_fuse_unhandled(x, y):
"""No fuse unhandled"""
x_intermediate = x + 2.8
@@ -259,10 +265,6 @@ def subtest_compile_and_run_unary_ufunc_correctness(
function = get_function(ufunc, upper_function)
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {arg_name: EncryptedScalar(Integer(64, False)) for arg_name in ["x", "y"]}
compiler_engine = compile_numpy_function(
@@ -291,10 +293,6 @@ def subtest_compile_and_run_binary_ufunc_correctness(
function = get_function(ufunc, upper_function)
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {arg_name: EncryptedScalar(Integer(64, True)) for arg_name in ["x", "y"]}
compiler_engine = compile_numpy_function(
@@ -476,7 +474,7 @@ def test_compile_function_multiple_outputs(
):
"""Test function compile_numpy_function_into_op_graph for a program with multiple outputs"""
def data_gen(args):
def data_gen_local(args):
for prod in itertools.product(*args):
yield tuple(numpy.array(val) for val in prod)
@@ -487,7 +485,7 @@ def test_compile_function_multiple_outputs(
op_graph = compile_numpy_function_into_op_graph(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
@@ -528,10 +526,6 @@ def test_compile_and_run_correctness(
):
"""Test correctness of results when running a compiled function"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
@@ -963,10 +957,6 @@ def test_compile_and_run_lut_correctness(
input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits)
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(input_bit, False))
for input_bit, arg_name in zip(input_bits, list_of_arg_names)
@@ -1349,7 +1339,7 @@ def test_compile_function_with_dot(
# This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'),
# we'll have to take random values, not all values one by one
def data_gen(max_for_ij, repeat):
def data_gen_local(max_for_ij, repeat):
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
for prod_i, prod_j in itertools.product(iter_i, iter_j):
@@ -1362,7 +1352,7 @@ def test_compile_function_with_dot(
op_graph = compile_numpy_function_into_op_graph(
function,
params,
data_gen(max_for_ij, repeat),
data_gen_local(max_for_ij, repeat),
default_compilation_configuration,
)
str_of_the_graph = format_operation_graph(op_graph)
@@ -1389,10 +1379,6 @@ def test_compile_with_show_mlir(
):
"""Test show_mlir option"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
@@ -1412,10 +1398,6 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
def function(x, y):
return x + y
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
"x": EncryptedScalar(Integer(64, False)),
"y": EncryptedScalar(Integer(64, False)),
@@ -1517,10 +1499,6 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration):
def test_wrong_inputs(default_compilation_configuration):
"""Test compilation with faulty inputs"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
# x should have been something like EncryptedScalar(UnsignedInteger(3))
x = [1, 2, 3]
input_ranges = ((0, 10),)
@@ -1576,10 +1554,6 @@ def test_compile_and_run_correctness_with_negative_values(
"""Test correctness of results when running a compiled function, which has some negative
intermediate values."""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
@@ -1614,10 +1588,6 @@ def test_compile_and_run_correctness_with_negative_results(
results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg,
instead of returning -3, the execution may return -3 mod 128 = 125."""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}