test: add tests with dot and negative tensors

refs #891
This commit is contained in:
Benoit Chevallier-Mames
2021-11-17 18:13:44 +01:00
committed by Benoit Chevallier
parent f6aadd37e9
commit bfa309a455
2 changed files with 76 additions and 24 deletions

View File

@@ -105,8 +105,8 @@ def check_node_compatibility_with_mlir(
elif isinstance(node, intermediate.Dot): # constraints for dot product
assert_true(len(inputs) == 2)
if not value_is_unsigned_integer(inputs[0]) or not value_is_unsigned_integer(inputs[1]):
return "only unsigned integer dot product is supported"
if not value_is_integer(inputs[0]) or not value_is_integer(inputs[1]):
return "only integer dot product is supported"
elif isinstance(node, intermediate.IndexConstant): # constraints for constant indexing
assert_true(len(outputs) == 1)

View File

@@ -1006,6 +1006,61 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
assert compiler_engine.run(*args) == function(*args)
@pytest.mark.parametrize(
"size, input_range_x, input_range_y,modulus",
[
pytest.param(10, (0, 3), (-3, 3), 32),
pytest.param(5, (0, 3), (-7, 7), 64),
],
)
def test_compile_and_run_dot_correctness_with_signed_cst(
size, input_range_x, input_range_y, default_compilation_configuration, modulus
):
"""Test correctness of dot with signed constant tensor. Remark that for now, the results are
only correct modulo modulus"""
low_x, high_x = input_range_x
low_y, high_y = input_range_y
shape = (size,)
inputset = [
(numpy.zeros(shape, dtype=numpy.uint32),),
(numpy.ones(shape, dtype=numpy.uint32) * low_x,),
(numpy.ones(shape, dtype=numpy.uint32) * high_x,),
]
for _ in range(8):
inputset.append((numpy.random.randint(low_x, high_x + 1),))
function_parameters = {
"x": EncryptedTensor(Integer(64, False), shape),
}
constant1 = numpy.random.randint(low_y, high_y + 1, size=(size,))
constant2 = numpy.random.randint(low_y, high_y + 1, size=(size,))
for i in range(2):
if i == 0:
def function(x):
return numpy.dot(x, constant1)
else:
def function(x):
return numpy.dot(constant2, x)
compiler_engine = compile_numpy_function(
function, function_parameters, inputset, default_compilation_configuration
)
for _ in range(5):
args = [
numpy.random.randint(low_x, high_x + 1, size=(size,), dtype=numpy.uint8),
]
assert check_equality_modulo(compiler_engine.run(*args), function(*args), modulus)
@pytest.mark.parametrize(
"size,input_range",
[
@@ -1258,37 +1313,34 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
"function,parameters,inputset,match",
[
pytest.param(
lambda x, y: numpy.dot(x, y),
lambda x: numpy.dot(x, numpy.array([-1.5])),
{
"x": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)),
"y": EncryptedTensor(Integer(2, is_signed=True), shape=(1,)),
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(1,)),
},
[
(numpy.array([-1]), numpy.array([-1])),
(numpy.array([-1]), numpy.array([0])),
(numpy.array([0]), numpy.array([-1])),
(numpy.array([0]), numpy.array([0])),
(numpy.array([1]), numpy.array([1])),
(numpy.array([1]), numpy.array([0])),
(numpy.array([0]), numpy.array([1])),
(numpy.array([0]), numpy.array([0])),
(numpy.array([-2]), numpy.array([-2])),
(numpy.array([-2]), numpy.array([1])),
(numpy.array([1]),),
(numpy.array([1]),),
(numpy.array([0]),),
(numpy.array([0]),),
(numpy.array([1]),),
(numpy.array([1]),),
(numpy.array([0]),),
(numpy.array([0]),),
(numpy.array([2]),),
(numpy.array([2]),),
],
(
"""
function you are trying to compile isn't supported for MLIR lowering
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<int2, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%1 = y # EncryptedTensor<int2, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer inputs are supported
%2 = dot(%0, %1) # EncryptedScalar<int4>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer dot product is supported
%0 = x # EncryptedTensor<uint2, shape=(1,)>
%1 = [-1.5] # ClearTensor<float64, shape=(1,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%2 = dot(%0, %1) # EncryptedScalar<float64>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer dot product is supported
return %2
""".strip() # noqa: E501
""".strip() # noqa: E501
),
),
pytest.param(