feat: add QuantizedReLU6 as a supported activation function

This commit is contained in:
jfrery
2021-11-24 11:19:15 +01:00
committed by jfrery
parent f53d374d1f
commit dfc762c2e2
3 changed files with 42 additions and 6 deletions

View File

@@ -1,6 +1,6 @@
"""Modules for quantization."""
from .post_training import PostTrainingAffineQuantization
from .quantized_activations import QuantizedSigmoid
from .quantized_activations import QuantizedReLU6, QuantizedSigmoid
from .quantized_array import QuantizedArray
from .quantized_layers import QuantizedLinear
from .quantized_module import QuantizedModule

View File

@@ -85,3 +85,27 @@ class QuantizedSigmoid(QuantizedActivation):
q_out = self.quant_output(quant_sigmoid)
return q_out
class QuantizedReLU6(QuantizedActivation):
"""Quantized ReLU6 activation function."""
def calibrate(self, x: numpy.ndarray):
x = numpy.minimum(numpy.maximum(0, x), 6)
self.q_out = QuantizedArray(self.n_bits, x)
def __call__(self, q_input: QuantizedArray) -> QuantizedArray:
"""Process the forward pass of the quantized ReLU6.
Args:
q_input (QuantizedArray): Quantized input.
Returns:
q_out (QuantizedArray): Quantized output.
"""
quant_relu6 = self.dequant_input(q_input)
quant_relu6 = numpy.minimum(numpy.maximum(0, quant_relu6), 6)
q_out = self.quant_output(quant_relu6)
return q_out

View File

@@ -2,7 +2,7 @@
import numpy
import pytest
from concrete.quantization import QuantizedArray, QuantizedSigmoid
from concrete.quantization import QuantizedArray, QuantizedReLU6, QuantizedSigmoid
N_BITS_ATOL_TUPLE_LIST = [
(32, 10 ** -2),
@@ -19,12 +19,24 @@ N_BITS_ATOL_TUPLE_LIST = [
[pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize(
"quant_activation, values",
[pytest.param(QuantizedSigmoid, numpy.random.uniform(size=(10, 40, 20)))],
"input_range",
[pytest.param((-1, 1)), pytest.param((-2, 2)), pytest.param((-10, 10)), pytest.param((0, 20))],
)
@pytest.mark.parametrize(
"input_shape",
[pytest.param((10, 40, 20)), pytest.param((100, 400))],
)
@pytest.mark.parametrize(
"quant_activation",
[
pytest.param(QuantizedSigmoid),
pytest.param(QuantizedReLU6),
],
)
@pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)])
def test_activations(quant_activation, values, n_bits, atol, is_signed):
def test_activations(quant_activation, input_shape, input_range, n_bits, atol, is_signed):
"""Test activation functions."""
values = numpy.random.uniform(input_range[0], input_range[1], size=input_shape)
q_inputs = QuantizedArray(n_bits, values, is_signed)
quant_sigmoid = quant_activation(n_bits)
quant_sigmoid.calibrate(values)
@@ -40,4 +52,4 @@ def test_activations(quant_activation, values, n_bits, atol, is_signed):
dequant_values = q_output.dequant()
# Check that all values are close
assert numpy.isclose(dequant_values, expected_output, atol=atol).all()
assert numpy.isclose(dequant_values.ravel(), expected_output.ravel(), atol=atol).all()