chore: fix test for dot with signed constants

closes #1123
This commit is contained in:
Arthur Meyre
2021-12-08 18:31:25 +01:00
parent 0c2c6f8298
commit 987f67b64c

View File

@@ -1177,17 +1177,16 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
@pytest.mark.parametrize(
"size, input_range_x, input_range_y,modulus",
"size, input_range_x, input_range_y",
[
pytest.param(6, (0, 3), (-3, 3), 32),
pytest.param(3, (0, 3), (-7, 7), 64),
pytest.param(6, (0, 3), (-3, 3)),
pytest.param(3, (0, 3), (-7, 7)),
],
)
def test_compile_and_run_dot_correctness_with_signed_cst(
size, input_range_x, input_range_y, default_compilation_configuration, modulus
size, input_range_x, input_range_y, default_compilation_configuration
):
"""Test correctness of dot with signed constant tensor. Remark that for now, the results are
only correct modulo modulus"""
"""Test correctness of dot with signed constant tensor."""
low_x, high_x = input_range_x
low_y, high_y = input_range_y
@@ -1222,7 +1221,7 @@ def test_compile_and_run_dot_correctness_with_signed_cst(
]
for _ in range(128):
inputset.append(numpy.random.randint(low_x, high_x + 1))
inputset.append(numpy.random.randint(low_x, high_x + 1, size=shape))
if i == 0:
@@ -1242,6 +1241,11 @@ def test_compile_and_run_dot_correctness_with_signed_cst(
function, function_parameters, inputset, default_compilation_configuration
)
# compute modulus used for the output
output_bit_width = compiler_engine.op_graph.output_nodes[0].outputs[0].dtype.bit_width
# bit width + 1 padding bit
modulus = 2 ** (output_bit_width + 1)
for _ in range(5):
args = [
numpy.random.randint(low_x, high_x + 1, size=(size,), dtype=numpy.uint8),