mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
test(execution): add tests with negative results
for now, results are correct mod 128, but wrong without this modulo reduction. Will be later fixed by the compiler. closes #844 refs #845
This commit is contained in:
committed by
Benoit Chevallier
parent
239f66eb46
commit
50c1ceb6db
@@ -115,8 +115,14 @@ def check_node_compatibility_with_mlir(
|
||||
|
||||
if is_output:
|
||||
for out in outputs:
|
||||
if not value_is_unsigned_integer(out):
|
||||
return "only unsigned integer outputs are supported"
|
||||
# For signed values and waiting for a real fix (#845): what is returned by the compiler
|
||||
# is not the (possibly negative) result r, but the always-positive (r mod 2**t), where t
|
||||
# is the bitwidth of r
|
||||
|
||||
# We currently can't fail on the following assert, but let it for possible changes in
|
||||
# the future
|
||||
if not value_is_integer(out):
|
||||
return "only integer outputs are supported" # pragma: no cover
|
||||
else:
|
||||
for out in outputs:
|
||||
# We currently can't fail on the following assert, but let it for possible changes in
|
||||
|
||||
@@ -1033,24 +1033,6 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: 1 - x,
|
||||
{"x": EncryptedScalar(Integer(3, is_signed=False))},
|
||||
[(i,) for i in range(8)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = 1 # ClearScalar<uint1>
|
||||
%1 = x # EncryptedScalar<uint3>
|
||||
%2 = sub(%0, %1) # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
@@ -1476,44 +1458,6 @@ return %2
|
||||
)
|
||||
|
||||
|
||||
def test_failure_for_signed_output(default_compilation_configuration):
|
||||
"""Test that we don't accept signed output"""
|
||||
function = lambda x: x + (-3) # noqa: E731
|
||||
input_ranges = ((0, 10),)
|
||||
list_of_arg_names = ["x"]
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== """
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = -3 # ClearScalar<int3>
|
||||
%2 = add(%0, %1) # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return %2
|
||||
|
||||
""".strip() # noqa: E501 # pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
|
||||
def test_compile_with_random_inputset(default_compilation_configuration):
|
||||
"""Test function for compile with random input set"""
|
||||
|
||||
@@ -1649,3 +1593,41 @@ def test_compile_and_run_correctness_with_negative_values(
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
def check_equality_modulo(a, b, modulus):
|
||||
"""Check that (a mod modulus) == (b mod modulus)"""
|
||||
return (a % modulus) == (b % modulus)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names,modulus",
|
||||
[
|
||||
pytest.param(lambda x: x + (-20), ((0, 10),), ["x"], 128),
|
||||
pytest.param(lambda x: 10 + x * (-3), ((0, 20),), ["x"], 128),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_correctness_with_negative_results(
|
||||
function, input_ranges, list_of_arg_names, modulus, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of computations when the result is possibly negative: until #845 is fixed,
|
||||
results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg,
|
||||
instead of returning -3, the execution may return -3 mod 128 = 125."""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
assert check_equality_modulo(compiler_engine.run(*args), function(*args), modulus)
|
||||
|
||||
Reference in New Issue
Block a user