mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
chore: fix #951 by making sure the dot product is never too large.
closes #951
This commit is contained in:
committed by
Benoit Chevallier
parent
8a27525a64
commit
247327b277
@@ -1060,8 +1060,8 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
|
||||
@pytest.mark.parametrize(
|
||||
"size, input_range_x, input_range_y,modulus",
|
||||
[
|
||||
pytest.param(10, (0, 3), (-3, 3), 32),
|
||||
pytest.param(5, (0, 3), (-7, 7), 64),
|
||||
pytest.param(6, (0, 3), (-3, 3), 32),
|
||||
pytest.param(3, (0, 3), (-7, 7), 64),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_dot_correctness_with_signed_cst(
|
||||
@@ -1074,13 +1074,12 @@ def test_compile_and_run_dot_correctness_with_signed_cst(
|
||||
low_y, high_y = input_range_y
|
||||
shape = (size,)
|
||||
|
||||
inputset = [
|
||||
(numpy.zeros(shape, dtype=numpy.uint32),),
|
||||
(numpy.ones(shape, dtype=numpy.uint32) * low_x,),
|
||||
(numpy.ones(shape, dtype=numpy.uint32) * high_x,),
|
||||
]
|
||||
for _ in range(8):
|
||||
inputset.append((numpy.random.randint(low_x, high_x + 1),))
|
||||
# Check that never, the dot goes too high
|
||||
# For this, we simplify our check knowing that low_x >= 0. Under this condition, the maximal
|
||||
# value is for the dot is size * max(abs(high_x * low_y), abs(high_x * high_y)). And we want
|
||||
# is to be less than 64, to have a signed value on strictly less than 8b
|
||||
assert low_x >= 0
|
||||
assert size * max(abs(high_x * low_y), abs(high_x * high_y)) < 64
|
||||
|
||||
function_parameters = {
|
||||
"x": EncryptedTensor(Integer(64, False), shape),
|
||||
@@ -1089,18 +1088,37 @@ def test_compile_and_run_dot_correctness_with_signed_cst(
|
||||
constant1 = numpy.random.randint(low_y, high_y + 1, size=(size,))
|
||||
constant2 = numpy.random.randint(low_y, high_y + 1, size=(size,))
|
||||
|
||||
worst_x_1_1 = numpy.where(constant1 < 0, 0, high_x)
|
||||
worst_x_1_2 = numpy.where(constant1 > 0, 0, high_x)
|
||||
|
||||
worst_x_2_1 = numpy.where(constant2 < 0, 0, high_x)
|
||||
worst_x_2_2 = numpy.where(constant2 > 0, 0, high_x)
|
||||
|
||||
for i in range(2):
|
||||
|
||||
inputset = [
|
||||
(numpy.zeros(shape, dtype=numpy.uint32),),
|
||||
(numpy.ones(shape, dtype=numpy.uint32) * low_x,),
|
||||
(numpy.ones(shape, dtype=numpy.uint32) * high_x,),
|
||||
]
|
||||
|
||||
for _ in range(128):
|
||||
inputset.append((numpy.random.randint(low_x, high_x + 1),))
|
||||
|
||||
if i == 0:
|
||||
|
||||
def function(x):
|
||||
return numpy.dot(x, constant1)
|
||||
|
||||
inputset.extend([(worst_x_1_1,), (worst_x_1_2,)])
|
||||
|
||||
else:
|
||||
|
||||
def function(x):
|
||||
return numpy.dot(constant2, x)
|
||||
|
||||
inputset.extend([(worst_x_2_1,), (worst_x_2_2,)])
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function, function_parameters, inputset, default_compilation_configuration
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user