Files
concrete/tests/quantization/test_quantized_array.py
2021-12-06 16:44:32 +03:00

55 lines
1.8 KiB
Python

"""Tests for the quantized array/tensors."""
import numpy
import pytest
from concrete.quantization import QuantizedArray
# Pairs of (n_bits, atol): a quantization bit-width together with the absolute
# tolerance passed to the closeness check on the dequantized values.
# Fewer bits means a coarser grid, hence the larger tolerances at the bottom.
N_BITS_ATOL_TUPLE_LIST = [
    (32, 10**-2),
    (28, 10**-2),
    (20, 10**-2),
    (16, 10**-1),
    (8, 10**0),
    (4, 10**0),
]
@pytest.mark.parametrize(
    "n_bits, atol",
    [pytest.param(*bits_and_tol) for bits_and_tol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)])
@pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))])
def test_quant_dequant_update(values, n_bits, atol, is_signed, check_array_equality):
    """Exercise quant/dequant and the update methods of QuantizedArray.

    The test runs for every (n_bits, atol) pair, for both signed and unsigned
    modes, on an array of 2000 normally distributed samples. It checks that:

    - the quantized values lie in the n_bits-wide range shifted by the
      array's offset,
    - dequantizing gives back the input within the absolute tolerance atol,
    - update_values / update_qvalues return arrays that differ from what
      was fed to the other one,
    - calling the array returns the qvalues that were last set.
    """
    q_array = QuantizedArray(n_bits, values, is_signed)
    quantized = q_array.quant()

    # Quantized values must lie in [-offset, 2**n_bits - 1 - offset]
    upper_bound = 2 ** n_bits - 1 - q_array.offset
    assert numpy.max(quantized) <= upper_bound
    lower_bound = -q_array.offset
    assert numpy.min(quantized) >= lower_bound

    # The round trip through quantization must stay within tolerance
    recovered = q_array.dequant()
    assert numpy.allclose(recovered, values, atol=atol)

    # Replace the float values with a much shorter array
    replacement_values = numpy.array([0.3, 0.5, -1.2, -3.4])
    qvalues_after_update = q_array.update_values(replacement_values)

    # The returned array must not have kept the shape of the first qvalues
    assert qvalues_after_update.shape != quantized.shape

    # Now overwrite the quantized values directly
    replacement_qvalues = numpy.array([1, 4, 7, 29])
    values_after_update = q_array.update_qvalues(replacement_qvalues)

    # Both updates must have produced something different from the other's input
    assert not numpy.array_equal(replacement_qvalues, qvalues_after_update)
    assert not numpy.array_equal(replacement_values, values_after_update)

    # Calling the array must return the qvalues that were just set
    check_array_equality(q_array(), replacement_qvalues)