chore: fix #951 by making sure the dot product is never too large.

closes #951
This commit is contained in:
Benoit Chevallier-Mames
2021-11-18 15:25:47 +01:00
committed by Benoit Chevallier
parent 8a27525a64
commit 247327b277

View File

@@ -1060,8 +1060,8 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
@pytest.mark.parametrize(
"size, input_range_x, input_range_y,modulus",
[
pytest.param(10, (0, 3), (-3, 3), 32),
pytest.param(5, (0, 3), (-7, 7), 64),
pytest.param(6, (0, 3), (-3, 3), 32),
pytest.param(3, (0, 3), (-7, 7), 64),
],
)
def test_compile_and_run_dot_correctness_with_signed_cst(
@@ -1074,13 +1074,12 @@ def test_compile_and_run_dot_correctness_with_signed_cst(
low_y, high_y = input_range_y
shape = (size,)
inputset = [
(numpy.zeros(shape, dtype=numpy.uint32),),
(numpy.ones(shape, dtype=numpy.uint32) * low_x,),
(numpy.ones(shape, dtype=numpy.uint32) * high_x,),
]
for _ in range(8):
inputset.append((numpy.random.randint(low_x, high_x + 1),))
# Check that never, the dot goes too high
# For this, we simplify our check knowing that low_x >= 0. Under this condition, the maximal
# value is for the dot is size * max(abs(high_x * low_y), abs(high_x * high_y)). And we want
# is to be less than 64, to have a signed value on strictly less than 8b
assert low_x >= 0
assert size * max(abs(high_x * low_y), abs(high_x * high_y)) < 64
function_parameters = {
"x": EncryptedTensor(Integer(64, False), shape),
@@ -1089,18 +1088,37 @@ def test_compile_and_run_dot_correctness_with_signed_cst(
constant1 = numpy.random.randint(low_y, high_y + 1, size=(size,))
constant2 = numpy.random.randint(low_y, high_y + 1, size=(size,))
worst_x_1_1 = numpy.where(constant1 < 0, 0, high_x)
worst_x_1_2 = numpy.where(constant1 > 0, 0, high_x)
worst_x_2_1 = numpy.where(constant2 < 0, 0, high_x)
worst_x_2_2 = numpy.where(constant2 > 0, 0, high_x)
for i in range(2):
inputset = [
(numpy.zeros(shape, dtype=numpy.uint32),),
(numpy.ones(shape, dtype=numpy.uint32) * low_x,),
(numpy.ones(shape, dtype=numpy.uint32) * high_x,),
]
for _ in range(128):
inputset.append((numpy.random.randint(low_x, high_x + 1),))
if i == 0:
def function(x):
return numpy.dot(x, constant1)
inputset.extend([(worst_x_1_1,), (worst_x_1_2,)])
else:
def function(x):
return numpy.dot(constant2, x)
inputset.extend([(worst_x_2_1,), (worst_x_2_2,)])
compiler_engine = compile_numpy_function(
function, function_parameters, inputset, default_compilation_configuration
)