From 987f67b64c4bdc6a9ceab8197a815047ce5f9749 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 8 Dec 2021 18:31:25 +0100 Subject: [PATCH] chore: fix test for dot with signed constants closes #1123 --- tests/numpy/test_compile.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 1b36510b5..023f99b69 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1177,17 +1177,16 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_ @pytest.mark.parametrize( - "size, input_range_x, input_range_y,modulus", + "size, input_range_x, input_range_y", [ - pytest.param(6, (0, 3), (-3, 3), 32), - pytest.param(3, (0, 3), (-7, 7), 64), + pytest.param(6, (0, 3), (-3, 3)), + pytest.param(3, (0, 3), (-7, 7)), ], ) def test_compile_and_run_dot_correctness_with_signed_cst( - size, input_range_x, input_range_y, default_compilation_configuration, modulus + size, input_range_x, input_range_y, default_compilation_configuration ): - """Test correctness of dot with signed constant tensor. Remark that for now, the results are - only correct modulo modulus""" + """Test correctness of dot with signed constant tensor.""" low_x, high_x = input_range_x low_y, high_y = input_range_y @@ -1222,7 +1221,7 @@ def test_compile_and_run_dot_correctness_with_signed_cst( ] for _ in range(128): - inputset.append(numpy.random.randint(low_x, high_x + 1)) + inputset.append(numpy.random.randint(low_x, high_x + 1, size=shape)) if i == 0: @@ -1242,6 +1241,11 @@ def test_compile_and_run_dot_correctness_with_signed_cst( function, function_parameters, inputset, default_compilation_configuration ) + # compute modulus used for the output + output_bit_width = compiler_engine.op_graph.output_nodes[0].outputs[0].dtype.bit_width + # bit width + 1 padding bit + modulus = 2 ** (output_bit_width + 1) + for _ in range(5): args = [ numpy.random.randint(low_x, high_x + 1, size=(size,), dtype=numpy.uint8),