Files
concrete/tests/quantization/test_quantized_activations.py
2021-11-10 18:25:31 +01:00

43 lines
1.2 KiB
Python

"""Tests for the quantized activation functions."""
import numpy
import pytest
from concrete.quantization import QuantizedArray, QuantizedSigmoid
N_BITS_ATOL_TUPLE_LIST = [
(32, 10 ** -2),
(28, 10 ** -2),
(20, 10 ** -2),
(16, 10 ** -1),
(8, 10 ** -0),
(4, 10 ** -0),
]
@pytest.mark.parametrize(
"n_bits, atol",
[pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize(
"quant_activation, values",
[pytest.param(QuantizedSigmoid, numpy.random.uniform(size=(10, 40, 20)))],
)
def test_activations(quant_activation, values, n_bits, atol):
"""Test activation functions."""
q_inputs = QuantizedArray(n_bits, values)
quant_sigmoid = quant_activation(n_bits)
quant_sigmoid.calibrate(values)
expected_output = quant_sigmoid.q_out.values
q_output = quant_sigmoid(q_inputs)
qvalues = q_output.qvalues
# Quantized values must be contained between 0 and 2**n_bits - 1.
assert numpy.max(qvalues) <= 2 ** n_bits - 1
assert numpy.min(qvalues) >= 0
# Dequantized values must be close to original values
dequant_values = q_output.dequant()
# Check that all values are close
assert numpy.isclose(dequant_values, expected_output, atol=atol).all()