From 4e6c57766abb94046ac16d41bb9c29f6fdf66ad4 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 17 Dec 2021 17:20:09 +0300 Subject: [PATCH] refactor: improve isclose used in quantization tests --- tests/quantization/test_quantized_array.py | 31 +++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/quantization/test_quantized_array.py b/tests/quantization/test_quantized_array.py index cf22bc4be..df1be2741 100644 --- a/tests/quantization/test_quantized_array.py +++ b/tests/quantization/test_quantized_array.py @@ -4,23 +4,14 @@ import pytest from concrete.quantization import QuantizedArray -N_BITS_ATOL_TUPLE_LIST = [ - (32, 10 ** -2), - (28, 10 ** -2), - (20, 10 ** -2), - (16, 10 ** -1), - (8, 10 ** -0), - (4, 10 ** -0), -] - @pytest.mark.parametrize( - "n_bits, atol", - [pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST], + "n_bits", + [32, 28, 20, 16, 8, 4], ) @pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)]) @pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))]) -def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality): +def test_quant_dequant_update(values, n_bits, is_signed, check_array_equality): """Test the quant and dequant function.""" quant_array = QuantizedArray(n_bits, values, is_signed) @@ -34,7 +25,21 @@ def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equal dequant_values = quant_array.dequant() # Check that all values are close - assert numpy.isclose(dequant_values, values, atol=atol).all() + tolerance = quant_array.scale / 2 + assert numpy.isclose(dequant_values, values, atol=tolerance).all() + + # Explain the choice of tolerance + # This test checks the values are quantized and dequantized correctly + # Each quantization have a maximum error per quantized value an it's `scale / 2` + + # To give an intuition, let's say you have the scale of 0.5 + # the range `[a + 0.00, a + 0.25]` will be quantized into 0, dequantized into `a + 0.00` + # the range `[a + 0.25, a + 0.75]` will be quantized into 1, dequantized into `a + 0.50` + # the range `[a + 0.75, a + 1.25]` will be quantized into 2, dequantized into `a + 1.00` + # ... + + # So for each quantization-then-dequantization operation, + # the maximum error is `0.25`, which is `scale / 2` # Test update functions new_values = numpy.array([0.3, 0.5, -1.2, -3.4])