refactor: improve isclose used in quantization tests

This commit is contained in:
Umut
2021-12-17 17:20:09 +03:00
parent 75d231fada
commit 4e6c57766a

View File

@@ -4,23 +4,14 @@ import pytest
from concrete.quantization import QuantizedArray
N_BITS_ATOL_TUPLE_LIST = [
(32, 10 ** -2),
(28, 10 ** -2),
(20, 10 ** -2),
(16, 10 ** -1),
(8, 10 ** -0),
(4, 10 ** -0),
]
@pytest.mark.parametrize(
"n_bits, atol",
[pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
"n_bits",
[32, 28, 20, 16, 8, 4],
)
@pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)])
@pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))])
def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality):
def test_quant_dequant_update(values, n_bits, is_signed, check_array_equality):
"""Test the quant and dequant function."""
quant_array = QuantizedArray(n_bits, values, is_signed)
@@ -34,7 +25,21 @@ def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equal
dequant_values = quant_array.dequant()
# Check that all values are close
assert numpy.isclose(dequant_values, values, atol=atol).all()
tolerance = quant_array.scale / 2
assert numpy.isclose(dequant_values, values, atol=tolerance).all()
# Explain the choice of tolerance
# This test checks the values are quantized and dequantized correctly
# Each quantization have a maximum error per quantized value an it's `scale / 2`
# To give an intuition, let's say you have the scale of 0.5
# the range `[a + 0.00, a + 0.25]` will be quantized into 0, dequantized into `a + 0.00`
# the range `[a + 0.25, a + 0.75]` will be quantized into 1, dequantized into `a + 0.50`
# the range `[a + 0.75, a + 1.25]` will be quantized into 2, dequantized into `a + 1.00`
# ...
# So for each quantization-then-dequantization operation,
# the maximum error is `0.25`, which is `scale / 2`
# Test update functions
new_values = numpy.array([0.3, 0.5, -1.2, -3.4])