test(execution): add tests with negative results

for now, results are correct mod 128, but wrong without this modulo reduction. Will be later fixed
by the compiler.

closes #844
refs #845
This commit is contained in:
Benoit Chevallier-Mames
2021-11-12 15:32:39 +01:00
committed by Benoit Chevallier
parent 239f66eb46
commit 50c1ceb6db
2 changed files with 46 additions and 58 deletions

View File

@@ -115,8 +115,14 @@ def check_node_compatibility_with_mlir(
if is_output:
for out in outputs:
if not value_is_unsigned_integer(out):
return "only unsigned integer outputs are supported"
# For signed values and waiting for a real fix (#845): what is returned by the compiler
# is not the (possibly negative) result r, but the always-positive (r mod 2**t), where t
# is the bitwidth of r
# We currently can't fail on the following assert, but let it for possible changes in
# the future
if not value_is_integer(out):
return "only integer outputs are supported" # pragma: no cover
else:
for out in outputs:
# We currently can't fail on the following assert, but let it for possible changes in

View File

@@ -1033,24 +1033,6 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
@pytest.mark.parametrize(
"function,parameters,inputset,match",
[
pytest.param(
lambda x: 1 - x,
{"x": EncryptedScalar(Integer(3, is_signed=False))},
[(i,) for i in range(8)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = 1 # ClearScalar<uint1>
%1 = x # EncryptedScalar<uint3>
%2 = sub(%0, %1) # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return %2
""".strip() # noqa: E501
),
),
pytest.param(
lambda x, y: numpy.dot(x, y),
{
@@ -1476,44 +1458,6 @@ return %2
)
def test_failure_for_signed_output(default_compilation_configuration):
"""Test that we don't accept signed output"""
function = lambda x: x + (-3) # noqa: E731
input_ranges = ((0, 10),)
list_of_arg_names = ["x"]
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
with pytest.raises(RuntimeError) as excinfo:
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
assert (
str(excinfo.value)
== """
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedScalar<uint4>
%1 = -3 # ClearScalar<int3>
%2 = add(%0, %1) # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return %2
""".strip() # noqa: E501 # pylint: disable=line-too-long
)
def test_compile_with_random_inputset(default_compilation_configuration):
"""Test function for compile with random input set"""
@@ -1649,3 +1593,41 @@ def test_compile_and_run_correctness_with_negative_values(
args = [random.randint(low, high) for (low, high) in input_ranges]
assert compiler_engine.run(*args) == function(*args)
def check_equality_modulo(a, b, modulus):
"""Check that (a mod modulus) == (b mod modulus)"""
return (a % modulus) == (b % modulus)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names,modulus",
[
pytest.param(lambda x: x + (-20), ((0, 10),), ["x"], 128),
pytest.param(lambda x: 10 + x * (-3), ((0, 20),), ["x"], 128),
],
)
def test_compile_and_run_correctness_with_negative_results(
function, input_ranges, list_of_arg_names, modulus, default_compilation_configuration
):
"""Test correctness of computations when the result is possibly negative: until #845 is fixed,
results are currently only correct modulo a power of 2 (given by `modulus` parameter). Eg,
instead of returning -3, the execution may return -3 mod 128 = 125."""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
default_compilation_configuration,
)
args = [random.randint(low, high) for (low, high) in input_ranges]
assert check_equality_modulo(compiler_engine.run(*args), function(*args), modulus)