From 247327b277fa26b202bc50471dea0558292ac03a Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Thu, 18 Nov 2021 15:25:47 +0100 Subject: [PATCH] chore: fix #951 by making sure the dot product is never too large. closes #951 --- tests/numpy/test_compile.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 951b9bd65..d3019ef2f 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1060,8 +1060,8 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_ @pytest.mark.parametrize( "size, input_range_x, input_range_y,modulus", [ - pytest.param(10, (0, 3), (-3, 3), 32), - pytest.param(5, (0, 3), (-7, 7), 64), + pytest.param(6, (0, 3), (-3, 3), 32), + pytest.param(3, (0, 3), (-7, 7), 64), ], ) def test_compile_and_run_dot_correctness_with_signed_cst( @@ -1074,13 +1074,12 @@ def test_compile_and_run_dot_correctness_with_signed_cst( low_y, high_y = input_range_y shape = (size,) - inputset = [ - (numpy.zeros(shape, dtype=numpy.uint32),), - (numpy.ones(shape, dtype=numpy.uint32) * low_x,), - (numpy.ones(shape, dtype=numpy.uint32) * high_x,), - ] - for _ in range(8): - inputset.append((numpy.random.randint(low_x, high_x + 1),)) + # Check that never, the dot goes too high + # For this, we simplify our check knowing that low_x >= 0. Under this condition, the maximal + # value is for the dot is size * max(abs(high_x * low_y), abs(high_x * high_y)). And we want + # is to be less than 64, to have a signed value on strictly less than 8b + assert low_x >= 0 + assert size * max(abs(high_x * low_y), abs(high_x * high_y)) < 64 function_parameters = { "x": EncryptedTensor(Integer(64, False), shape), @@ -1089,18 +1088,37 @@ def test_compile_and_run_dot_correctness_with_signed_cst( constant1 = numpy.random.randint(low_y, high_y + 1, size=(size,)) constant2 = numpy.random.randint(low_y, high_y + 1, size=(size,)) + worst_x_1_1 = numpy.where(constant1 < 0, 0, high_x) + worst_x_1_2 = numpy.where(constant1 > 0, 0, high_x) + + worst_x_2_1 = numpy.where(constant2 < 0, 0, high_x) + worst_x_2_2 = numpy.where(constant2 > 0, 0, high_x) + for i in range(2): + inputset = [ + (numpy.zeros(shape, dtype=numpy.uint32),), + (numpy.ones(shape, dtype=numpy.uint32) * low_x,), + (numpy.ones(shape, dtype=numpy.uint32) * high_x,), + ] + + for _ in range(128): + inputset.append((numpy.random.randint(low_x, high_x + 1),)) + if i == 0: def function(x): return numpy.dot(x, constant1) + inputset.extend([(worst_x_1_1,), (worst_x_1_2,)]) + else: def function(x): return numpy.dot(constant2, x) + inputset.extend([(worst_x_2_1,), (worst_x_2_2,)]) + compiler_engine = compile_numpy_function( function, function_parameters, inputset, default_compilation_configuration )