mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
chore: factorize some data_gen.
This commit is contained in:
committed by
Benoit Chevallier
parent
50c1ceb6db
commit
5d31aa4d2c
@@ -18,6 +18,12 @@ from concrete.numpy.compile import compile_numpy_function, compile_numpy_functio
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
|
||||
def data_gen(args):
|
||||
"""Helper to create an inputset"""
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
|
||||
def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
x_intermediate = x + 2.8
|
||||
@@ -259,10 +265,6 @@ def subtest_compile_and_run_unary_ufunc_correctness(
|
||||
|
||||
function = get_function(ufunc, upper_function)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {arg_name: EncryptedScalar(Integer(64, False)) for arg_name in ["x", "y"]}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
@@ -291,10 +293,6 @@ def subtest_compile_and_run_binary_ufunc_correctness(
|
||||
|
||||
function = get_function(ufunc, upper_function)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {arg_name: EncryptedScalar(Integer(64, True)) for arg_name in ["x", "y"]}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
@@ -476,7 +474,7 @@ def test_compile_function_multiple_outputs(
|
||||
):
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
def data_gen_local(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield tuple(numpy.array(val) for val in prod)
|
||||
|
||||
@@ -487,7 +485,7 @@ def test_compile_function_multiple_outputs(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
data_gen_local(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
@@ -528,10 +526,6 @@ def test_compile_and_run_correctness(
|
||||
):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
@@ -963,10 +957,6 @@ def test_compile_and_run_lut_correctness(
|
||||
|
||||
input_ranges = tuple((0, 2 ** input_bit - 1) for input_bit in input_bits)
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(input_bit, False))
|
||||
for input_bit, arg_name in zip(input_bits, list_of_arg_names)
|
||||
@@ -1349,7 +1339,7 @@ def test_compile_function_with_dot(
|
||||
|
||||
# This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'),
|
||||
# we'll have to take random values, not all values one by one
|
||||
def data_gen(max_for_ij, repeat):
|
||||
def data_gen_local(max_for_ij, repeat):
|
||||
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
for prod_i, prod_j in itertools.product(iter_i, iter_j):
|
||||
@@ -1362,7 +1352,7 @@ def test_compile_function_with_dot(
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
params,
|
||||
data_gen(max_for_ij, repeat),
|
||||
data_gen_local(max_for_ij, repeat),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
str_of_the_graph = format_operation_graph(op_graph)
|
||||
@@ -1389,10 +1379,6 @@ def test_compile_with_show_mlir(
|
||||
):
|
||||
"""Test show_mlir option"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
@@ -1412,10 +1398,6 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
|
||||
def function(x, y):
|
||||
return x + y
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
"x": EncryptedScalar(Integer(64, False)),
|
||||
"y": EncryptedScalar(Integer(64, False)),
|
||||
@@ -1517,10 +1499,6 @@ def test_fail_compile_with_random_inputset(default_compilation_configuration):
|
||||
def test_wrong_inputs(default_compilation_configuration):
|
||||
"""Test compilation with faulty inputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
# x should have been something like EncryptedScalar(UnsignedInteger(3))
|
||||
x = [1, 2, 3]
|
||||
input_ranges = ((0, 10),)
|
||||
@@ -1576,10 +1554,6 @@ def test_compile_and_run_correctness_with_negative_values(
|
||||
"""Test correctness of results when running a compiled function, which has some negative
|
||||
intermediate values."""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
@@ -1614,10 +1588,6 @@ def test_compile_and_run_correctness_with_negative_results(
|
||||
results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg,
|
||||
instead of returning -3, the execution may return -3 mod 128 = 125."""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user