mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: improve isclose used in quantization tests
This commit is contained in:
@@ -4,23 +4,14 @@ import pytest
|
||||
|
||||
from concrete.quantization import QuantizedArray
|
||||
|
||||
N_BITS_ATOL_TUPLE_LIST = [
|
||||
(32, 10 ** -2),
|
||||
(28, 10 ** -2),
|
||||
(20, 10 ** -2),
|
||||
(16, 10 ** -1),
|
||||
(8, 10 ** -0),
|
||||
(4, 10 ** -0),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"n_bits, atol",
|
||||
[pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
|
||||
"n_bits",
|
||||
[32, 28, 20, 16, 8, 4],
|
||||
)
|
||||
@pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)])
|
||||
@pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))])
|
||||
def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality):
|
||||
def test_quant_dequant_update(values, n_bits, is_signed, check_array_equality):
|
||||
"""Test the quant and dequant function."""
|
||||
|
||||
quant_array = QuantizedArray(n_bits, values, is_signed)
|
||||
@@ -34,7 +25,21 @@ def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equal
|
||||
dequant_values = quant_array.dequant()
|
||||
|
||||
# Check that all values are close
|
||||
assert numpy.isclose(dequant_values, values, atol=atol).all()
|
||||
tolerance = quant_array.scale / 2
|
||||
assert numpy.isclose(dequant_values, values, atol=tolerance).all()
|
||||
|
||||
# Explain the choice of tolerance
|
||||
# This test checks the values are quantized and dequantized correctly
|
||||
# Each quantization have a maximum error per quantized value an it's `scale / 2`
|
||||
|
||||
# To give an intuition, let's say you have the scale of 0.5
|
||||
# the range `[a + 0.00, a + 0.25]` will be quantized into 0, dequantized into `a + 0.00`
|
||||
# the range `[a + 0.25, a + 0.75]` will be quantized into 1, dequantized into `a + 0.50`
|
||||
# the range `[a + 0.75, a + 1.25]` will be quantized into 2, dequantized into `a + 1.00`
|
||||
# ...
|
||||
|
||||
# So for each quantization-then-dequantization operation,
|
||||
# the maximum error is `0.25`, which is `scale / 2`
|
||||
|
||||
# Test update functions
|
||||
new_values = numpy.array([0.3, 0.5, -1.2, -3.4])
|
||||
|
||||
Reference in New Issue
Block a user