mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
committed by
Benoit Chevallier
parent
1394dd6db5
commit
e7e7a02425
@@ -1513,3 +1513,54 @@ def test_wrong_inputs(default_compilation_configuration):
|
||||
str(excinfo.value) == f"wrong type for inputs {dict_for_inputs}, "
|
||||
f"needs to be one of {list_of_possible_basevalue}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(lambda x: (x + (-27)) + 32, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: ((-3) * x) + (100 - (x + 1)), ((0, 10),), ["x"]),
|
||||
# FIXME: doesn't work for now, #885
|
||||
# pytest.param(
|
||||
# lambda x: (50 * (numpy.cos(x + 33.0))).astype(numpy.uint32),
|
||||
# ((0, 31),),
|
||||
# ["x"],
|
||||
# ),
|
||||
pytest.param(
|
||||
lambda x: (20 * (numpy.cos(x + 33.0)) + 30).astype(numpy.uint32),
|
||||
((0, 31),),
|
||||
["x"],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: (-1) * x + (-2) * y + 40,
|
||||
(
|
||||
(0, 10),
|
||||
(0, 10),
|
||||
),
|
||||
["x", "y"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_correctness_with_negative_values(
|
||||
function, input_ranges, list_of_arg_names, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function, which has some negative
|
||||
intermediate values."""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
Reference in New Issue
Block a user