diff --git a/concrete/quantization/post_training.py b/concrete/quantization/post_training.py
index 4736fcb1b..0035e258a 100644
--- a/concrete/quantization/post_training.py
+++ b/concrete/quantization/post_training.py
@@ -4,7 +4,7 @@ import numpy
 from torch import nn
 
 from ..torch import NumpyModule
-from .quantized_activations import QuantizedSigmoid
+from .quantized_activations import QuantizedReLU6, QuantizedSigmoid
 from .quantized_array import QuantizedArray
 from .quantized_layers import QuantizedLinear
 from .quantized_module import QuantizedModule
@@ -13,7 +13,7 @@ from .quantized_module import QuantizedModule
 class PostTrainingAffineQuantization:
     """Post-training Affine Quantization."""
 
-    IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid}
+    IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid, nn.ReLU6}
 
     quant_layers_dict: dict
     n_bits: int
@@ -104,6 +104,10 @@ class PostTrainingAffineQuantization:
                 calibration_data = self._calibrate_layers_activation(
                     name, q_sigmoid, calibration_data
                 )
+            elif isinstance(layer, nn.ReLU6):
+                # Create a new quantized layer (based on type(layer))
+                q_relu = QuantizedReLU6(n_bits=self.n_bits)
+                calibration_data = self._calibrate_layers_activation(name, q_relu, calibration_data)
             else:  # pragma: no cover
                 # If we find a layer that has not been implemented we throw an error
                 hf_m_names = sorted(module.__name__ for module in self.IMPLEMENTED_MODULES)
diff --git a/concrete/torch/numpy_module.py b/concrete/torch/numpy_module.py
index d4fe541b3..717b2399b 100644
--- a/concrete/torch/numpy_module.py
+++ b/concrete/torch/numpy_module.py
@@ -6,7 +6,7 @@ from torch import nn
 class NumpyModule:
     """General interface to transform a torch.nn.Module to numpy module."""
 
-    IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid}
+    IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid, nn.ReLU6}
 
     def __init__(self, torch_model: nn.Module):
         """Initialize our numpy module.
@@ -68,5 +68,6 @@ class NumpyModule:
                 )
             elif isinstance(layer, nn.Sigmoid):
                 x = 1 / (1 + numpy.exp(-x))
-
+            elif isinstance(layer, nn.ReLU6):
+                x = numpy.minimum(numpy.maximum(0, x), 6)
         return x
diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py
index a92ca091b..2b41e5ffa 100644
--- a/tests/torch/test_compile_torch.py
+++ b/tests/torch/test_compile_torch.py
@@ -14,21 +14,28 @@ INPUT_OUTPUT_FEATURE = [1, 2, 3]
 class FC(nn.Module):
     """Torch model for the tests"""
 
-    def __init__(self, input_output):
+    def __init__(self, input_output, activation_function):
         super().__init__()
         self.fc1 = nn.Linear(in_features=input_output, out_features=input_output)
-        self.sigmoid1 = nn.Sigmoid()
+        self.act_f = activation_function()
         self.fc2 = nn.Linear(in_features=input_output, out_features=input_output)
 
     def forward(self, x):
         """Forward pass."""
         out = self.fc1(x)
-        out = self.sigmoid1(out)
+        out = self.act_f(out)
         out = self.fc2(out)
 
         return out
 
 
+@pytest.mark.parametrize(
+    "activation_function",
+    [
+        pytest.param(nn.Sigmoid, id="sigmoid"),
+        pytest.param(nn.ReLU6, id="relu"),
+    ],
+)
 @pytest.mark.parametrize(
     "model",
     [pytest.param(FC)],
@@ -40,6 +47,7 @@ class FC(nn.Module):
 def test_compile_torch(
     input_output_feature,
     model,
+    activation_function,
     seed_torch,
     default_compilation_configuration,
     check_is_good_execution,
@@ -55,7 +63,7 @@ def test_compile_torch(
     n_examples = 50
 
     # Define the torch model
-    torch_fc_model = model(input_output_feature)
+    torch_fc_model = model(input_output_feature, activation_function)
     # Create random input
     inputset = [
         numpy.random.uniform(-100, 100, size=input_output_feature) for _ in range(n_examples)
diff --git a/tests/torch/test_torch_to_numpy.py b/tests/torch/test_torch_to_numpy.py
index 127d6bf01..3dbbb430b 100644
--- a/tests/torch/test_torch_to_numpy.py
+++ b/tests/torch/test_torch_to_numpy.py
@@ -33,28 +33,28 @@ class CNN(nn.Module):
 class FC(nn.Module):
     """Torch model for the tests"""
 
-    def __init__(self):
+    def __init__(self, activation_function):
         super().__init__()
         self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
-        self.sigmoid1 = nn.Sigmoid()
+        self.act_1 = activation_function()
         self.fc2 = nn.Linear(in_features=128, out_features=64)
-        self.sigmoid2 = nn.Sigmoid()
+        self.act_2 = activation_function()
         self.fc3 = nn.Linear(in_features=64, out_features=64)
-        self.sigmoid3 = nn.Sigmoid()
+        self.act_3 = activation_function()
         self.fc4 = nn.Linear(in_features=64, out_features=64)
-        self.sigmoid4 = nn.Sigmoid()
+        self.act_4 = activation_function()
         self.fc5 = nn.Linear(in_features=64, out_features=10)
 
     def forward(self, x):
         """Forward pass."""
         out = self.fc1(x)
-        out = self.sigmoid1(out)
+        out = self.act_1(out)
         out = self.fc2(out)
-        out = self.sigmoid2(out)
+        out = self.act_2(out)
         out = self.fc3(out)
-        out = self.sigmoid3(out)
+        out = self.act_3(out)
         out = self.fc4(out)
-        out = self.sigmoid4(out)
+        out = self.act_4(out)
         out = self.fc5(out)
 
         return out
@@ -66,13 +66,20 @@ class FC(nn.Module):
         pytest.param(FC, (100, 32 * 32 * 3)),
     ],
 )
-def test_torch_to_numpy(model, input_shape, seed_torch):
+@pytest.mark.parametrize(
+    "activation_function",
+    [
+        pytest.param(nn.Sigmoid, id="sigmoid"),
+        pytest.param(nn.ReLU6, id="relu"),
+    ],
+)
+def test_torch_to_numpy(model, input_shape, activation_function, seed_torch):
    """Test the different model architecture from torch numpy."""
 
     # Seed torch
     seed_torch()
     # Define the torch model
-    torch_fc_model = model()
+    torch_fc_model = model(activation_function)
     # Create random input
     torch_input_1 = torch.randn(input_shape)
     # Predict with torch model
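
Reviewer note: a minimal standalone sketch (not part of the patch) of the equivalence the NumpyModule change relies on. torch's nn.ReLU6 clamps its input to the range [0, 6], which is what numpy.minimum(numpy.maximum(0, x), 6) computes; the numpy_relu6 helper name below is illustrative only.

```python
import numpy
import torch
from torch import nn


def numpy_relu6(x: numpy.ndarray) -> numpy.ndarray:
    """Clamp values to [0, 6], mirroring the expression used in NumpyModule."""
    return numpy.minimum(numpy.maximum(0, x), 6)


# Compare against torch's reference ReLU6 on random data covering both clamping bounds.
torch_input = torch.randn(100, 10) * 10
torch_output = nn.ReLU6()(torch_input).numpy()
numpy_output = numpy_relu6(torch_input.numpy())

assert numpy.allclose(torch_output, numpy_output)
```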