From dfc762c2e208fcea9c82faf4da9fa4f2d15434b1 Mon Sep 17 00:00:00 2001 From: jfrery Date: Wed, 24 Nov 2021 11:19:15 +0100 Subject: [PATCH] feat: add QuantizedReLU6 as a supported activation function --- concrete/quantization/__init__.py | 2 +- .../quantization/quantized_activations.py | 24 +++++++++++++++++++ .../test_quantized_activations.py | 22 +++++++++++++---- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/concrete/quantization/__init__.py b/concrete/quantization/__init__.py index fdf6a7432..c9facc1dc 100644 --- a/concrete/quantization/__init__.py +++ b/concrete/quantization/__init__.py @@ -1,6 +1,6 @@ """Modules for quantization.""" from .post_training import PostTrainingAffineQuantization -from .quantized_activations import QuantizedSigmoid +from .quantized_activations import QuantizedReLU6, QuantizedSigmoid from .quantized_array import QuantizedArray from .quantized_layers import QuantizedLinear from .quantized_module import QuantizedModule diff --git a/concrete/quantization/quantized_activations.py b/concrete/quantization/quantized_activations.py index 5bba4174d..3f02bcea7 100644 --- a/concrete/quantization/quantized_activations.py +++ b/concrete/quantization/quantized_activations.py @@ -85,3 +85,27 @@ class QuantizedSigmoid(QuantizedActivation): q_out = self.quant_output(quant_sigmoid) return q_out + + +class QuantizedReLU6(QuantizedActivation): + """Quantized ReLU6 activation function.""" + + def calibrate(self, x: numpy.ndarray): + x = numpy.minimum(numpy.maximum(0, x), 6) + self.q_out = QuantizedArray(self.n_bits, x) + + def __call__(self, q_input: QuantizedArray) -> QuantizedArray: + """Process the forward pass of the quantized ReLU6. + + Args: + q_input (QuantizedArray): Quantized input. + + Returns: + q_out (QuantizedArray): Quantized output. + """ + + quant_relu6 = self.dequant_input(q_input) + quant_relu6 = numpy.minimum(numpy.maximum(0, quant_relu6), 6) + + q_out = self.quant_output(quant_relu6) + return q_out diff --git a/tests/quantization/test_quantized_activations.py b/tests/quantization/test_quantized_activations.py index e6b25c02d..349267f1b 100644 --- a/tests/quantization/test_quantized_activations.py +++ b/tests/quantization/test_quantized_activations.py @@ -2,7 +2,7 @@ import numpy import pytest -from concrete.quantization import QuantizedArray, QuantizedSigmoid +from concrete.quantization import QuantizedArray, QuantizedReLU6, QuantizedSigmoid N_BITS_ATOL_TUPLE_LIST = [ (32, 10 ** -2), @@ -19,12 +19,24 @@ N_BITS_ATOL_TUPLE_LIST = [ [pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST], ) @pytest.mark.parametrize( - "quant_activation, values", - [pytest.param(QuantizedSigmoid, numpy.random.uniform(size=(10, 40, 20)))], + "input_range", + [pytest.param((-1, 1)), pytest.param((-2, 2)), pytest.param((-10, 10)), pytest.param((0, 20))], +) +@pytest.mark.parametrize( + "input_shape", + [pytest.param((10, 40, 20)), pytest.param((100, 400))], +) +@pytest.mark.parametrize( + "quant_activation", + [ + pytest.param(QuantizedSigmoid), + pytest.param(QuantizedReLU6), + ], ) @pytest.mark.parametrize("is_signed", [pytest.param(True), pytest.param(False)]) -def test_activations(quant_activation, values, n_bits, atol, is_signed): +def test_activations(quant_activation, input_shape, input_range, n_bits, atol, is_signed): """Test activation functions.""" + values = numpy.random.uniform(input_range[0], input_range[1], size=input_shape) q_inputs = QuantizedArray(n_bits, values, is_signed) quant_sigmoid = quant_activation(n_bits) quant_sigmoid.calibrate(values) @@ -40,4 +52,4 @@ def test_activations(quant_activation, values, n_bits, atol, is_signed): dequant_values = q_output.dequant() # Check that all values are close - assert numpy.isclose(dequant_values, expected_output, atol=atol).all() + assert numpy.isclose(dequant_values.ravel(), expected_output.ravel(), atol=atol).all()