feat: static post training quantization and quantization module

Arthur Meyre
2021-11-16 11:08:40 +01:00
committed by jfrery
parent bfa309a455
commit 507ccd05c5
5 changed files with 243 additions and 1 deletion

View File

@@ -1,4 +1,6 @@
"""Modules for quantization."""
from .post_training import PostTrainingAffineQuantization
from .quantized_activations import QuantizedSigmoid
from .quantized_array import QuantizedArray
from .quantized_layers import QuantizedLinear
from .quantized_module import QuantizedModule
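
Note: all of these classes build on QuantizedArray, which implements the affine (scale and zero point) quantization scheme of https://arxiv.org/abs/1712.05877, cited in post_training.py below. As background, here is a minimal numpy sketch of that scheme; the function names and edge-case handling are illustrative assumptions, not this package's actual implementation:

import numpy

def affine_quantize(x: numpy.ndarray, n_bits: int):
    """Map real values to n_bits unsigned integers with a scale and zero point."""
    x_min, x_max = x.min(), x.max()
    # Assumes x_max > x_min; a real implementation must handle constant arrays
    scale = (x_max - x_min) / (2 ** n_bits - 1)
    zero_point = int(round(-x_min / scale))
    qvalues = numpy.clip(numpy.round(x / scale) + zero_point, 0, 2 ** n_bits - 1)
    return qvalues.astype(numpy.int64), scale, zero_point

def affine_dequantize(qvalues, scale, zero_point):
    """Map the integers back to approximate real values."""
    return scale * (qvalues - zero_point)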

View File

@@ -0,0 +1,106 @@
"""Post Training Quantization methods."""
import numpy
from torch import nn
from concrete.torch import NumpyModule
from .quantized_activations import QuantizedSigmoid
from .quantized_array import QuantizedArray
from .quantized_layers import QuantizedLinear
from .quantized_module import QuantizedModule
class PostTrainingAffineQuantization:
"""Post-training Affine Quantization."""
IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid}
quant_layers_dict: dict
n_bits: int
quant_params: dict
numpy_model: NumpyModule
    def __init__(self, n_bits: int, numpy_model: NumpyModule):
        """Create a post-training quantizer for a numpy module.

        Args:
            n_bits (int): Number of bits to quantize the model. Currently, this
                single n_bits value is used for all activations, inputs and weights.
            numpy_model (NumpyModule): Model in numpy.
        """
        self.quant_layers_dict = {}
        self.n_bits = n_bits
        self.quant_params = {}
        self.numpy_model = numpy_model
    def quantize_module(self, calibration_data: numpy.ndarray) -> QuantizedModule:
        """Quantize the numpy module.

        Follows the guidelines of https://arxiv.org/abs/1712.05877.

        Args:
            calibration_data (numpy.ndarray): Data that will be used to compute
                the bounds, scales and zero point values for every quantized object.

        Returns:
            QuantizedModule: Quantized numpy module.
        """
        # First, transform all parameters to their quantized version
        self._quantize_params()

        # Quantize and calibrate each output layer/activation
        self._quantize_layers(calibration_data=calibration_data)

        # Create a quantized module from self.quant_layers_dict
        return QuantizedModule(self.quant_layers_dict)
    def _quantize_params(self):
        """Transform all floating point parameters to integers."""
        for name, params in self.numpy_model.numpy_module_dict.items():
            self.quant_params[name] = QuantizedArray(self.n_bits, params)
    def _calibrate_layers_activation(self, name, q_function, calibration_data):
        """Calibrate a quantized layer/activation and return its dequantized output."""
        # Calibrate the output of the layer on the calibration data
        q_function.calibrate(calibration_data)

        # Store the calibrated quantized layer
        self.quant_layers_dict[name] = q_function

        # Quantize the calibration data to feed it to the quantized layer
        q_calibration_data = QuantizedArray(self.n_bits, calibration_data)

        # Dequantize the output to get values in the clear, ready for the next calibration
        return q_function(q_calibration_data).dequant()
    def _quantize_layers(self, calibration_data: numpy.ndarray):
        """Compute all parameters for the static post-training quantization.

        Does a forward pass over a batch of data and computes all
        quantization parameters for activations and layers.
        """
        for name, layer in self.numpy_model.torch_model.named_children():
            if isinstance(layer, nn.Linear):
                # Create a QuantizedLinear layer
                q_weights = self.quant_params[f"{name}.weight"]
                q_bias = self.quant_params[f"{name}.bias"]
                q_layer = QuantizedLinear(self.n_bits, q_weights, q_bias)

                # Calibrate and get new calibration_data for the next layer/activation
                calibration_data = self._calibrate_layers_activation(
                    name, q_layer, calibration_data
                )
            elif isinstance(layer, nn.Sigmoid):
                # Create a new quantized activation (based on type(layer))
                q_sigmoid = QuantizedSigmoid(n_bits=self.n_bits)
                calibration_data = self._calibrate_layers_activation(
                    name, q_sigmoid, calibration_data
                )
            else:  # pragma: no cover
                # Raise an error for any layer that has not been implemented yet
                hf_m_names = sorted(module.__name__ for module in self.IMPLEMENTED_MODULES)
                raise ValueError(
                    f"The following module is currently not implemented: "
                    f"{type(layer).__name__}. Please stick to the available "
                    f"quantized modules: {', '.join(hf_m_names)}."
                )
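
Tying the file together, a hedged end-to-end sketch of the intended flow (torch model, then NumpyModule, then quantize_module, then quantized inference). TinyNet and its shapes are made up for illustration; the calls mirror the test file below:

import numpy
from torch import nn

from concrete.quantization import PostTrainingAffineQuantization, QuantizedArray
from concrete.torch import NumpyModule

class TinyNet(nn.Module):
    """Hypothetical model using only the implemented modules (Linear, Sigmoid)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 4)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        """Forward pass."""
        return self.fc2(self.sigmoid1(self.fc1(x)))

# The calibration data drives the scale/zero point computation for every layer
calibration_data = numpy.random.uniform(size=(100, 10))

numpy_model = NumpyModule(TinyNet())
quantizer = PostTrainingAffineQuantization(n_bits=8, numpy_model=numpy_model)
quantized_module = quantizer.quantize_module(calibration_data)

# Inference takes quantized inputs; dequantize to read the result in the clear
q_x = QuantizedArray(8, calibration_data)
predictions = quantized_module(q_x).dequant()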

View File

@@ -0,0 +1,28 @@
"""QuantizedModule API."""
import copy
from .quantized_array import QuantizedArray
class QuantizedModule:
"""Inference for a quantized model."""
def __init__(self, quant_layers_dict: dict):
self.quant_layers_dict = copy.deepcopy(quant_layers_dict)
def __call__(self, x: QuantizedArray) -> QuantizedArray:
return self.forward(x)
def forward(self, q_x: QuantizedArray) -> QuantizedArray:
"""Forward pass with numpy function only.
Args:
q_x (QuantizedArray): QuantizedArray containing the inputs.
Returns:
(QuantizedArray): Prediction of the quantized model
"""
for _, layer in self.quant_layers_dict.items():
q_x = layer(q_x)
return q_x
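
A design note on forward: it iterates over quant_layers_dict in insertion order, so the execution order is exactly the order in which PostTrainingAffineQuantization._quantize_layers inserted the layers (Python dicts preserve insertion order since 3.7). A toy sketch of that contract, with plain functions standing in for calibrated quantized layers:

from concrete.quantization import QuantizedModule

# Hypothetical stand-ins: QuantizedModule only requires that each dict value is
# callable on the previous value's output
layers = {
    "fc1": lambda q_x: q_x,  # a calibrated QuantizedLinear in practice
    "sigmoid1": lambda q_x: q_x,  # a calibrated QuantizedSigmoid in practice
}
module = QuantizedModule(layers)
# module(q_x) computes layers["sigmoid1"](layers["fc1"](q_x)), in insertion order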

View File

@@ -39,7 +39,6 @@ class NumpyModule:
    def convert_to_numpy(self):
        """Transform all parameters from torch tensor to numpy arrays."""
        self.numpy_module_dict = {}
-       self.numpy_module_quant_dict = {}
        for name, weights in self.torch_model.state_dict().items():
            params = weights.detach().numpy()
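
For context on the dict consumed by _quantize_params above: numpy_module_dict mirrors the torch state_dict, keyed by qualified parameter names such as "fc1.weight" and "fc1.bias", which is what lets _quantize_layers look up f"{name}.weight". A small sketch, assuming NumpyModule populates the dict at construction time:

from torch import nn

from concrete.torch import NumpyModule

# Hypothetical one-layer model
model = nn.Sequential()
model.add_module("fc1", nn.Linear(4, 2))

numpy_model = NumpyModule(model)
# Expected to print ['fc1.bias', 'fc1.weight'], mirroring model.state_dict()
print(sorted(numpy_model.numpy_module_dict.keys()))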

View File

@@ -0,0 +1,107 @@
"""Tests for the quantized module."""
import numpy
import pytest
import torch
from torch import nn
from concrete.quantization import PostTrainingAffineQuantization, QuantizedArray
from concrete.torch import NumpyModule
class CNN(nn.Module):
"""Torch CNN model for the tests."""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.AvgPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
"""Forward pass."""
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
class FC(nn.Module):
    """Torch fully connected model for the tests."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.sigmoid2 = nn.Sigmoid()
        self.fc3 = nn.Linear(in_features=64, out_features=64)
        self.sigmoid3 = nn.Sigmoid()
        self.fc4 = nn.Linear(in_features=64, out_features=64)
        self.sigmoid4 = nn.Sigmoid()
        self.fc5 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Forward pass."""
        out = self.fc1(x)
        out = self.sigmoid1(out)
        out = self.fc2(out)
        out = self.sigmoid2(out)
        out = self.fc3(out)
        out = self.sigmoid3(out)
        out = self.fc4(out)
        out = self.sigmoid4(out)
        out = self.fc5(out)
        return out
N_BITS_ATOL_TUPLE_LIST = [
    (28, 10 ** -2),
    (20, 10 ** -2),
    (16, 10 ** -1),
    (8, 10 ** -0),
    (4, 10 ** -0),
]
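# Note: atol is loosened as n_bits shrinks. With fewer bits the affine
# quantization grid is coarser, so the dequantized outputs are expected to
# drift further from the floating point results.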
@pytest.mark.parametrize(
    "n_bits, atol",
    [pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize(
    "model, input_shape",
    [
        pytest.param(FC, (100, 32 * 32 * 3)),
    ],
)
def test_quantized_linear(model, input_shape, n_bits, atol):
    """Test the quantized module with post-training static quantization.

    With a large n_bits we expect the dequantized results of the quantized
    module to be close (within atol) to those of the floating point module.
    """
    # Define the torch model
    torch_fc_model = model()

    # Create a random input
    numpy_input = numpy.random.uniform(size=input_shape)

    # Create the corresponding numpy model
    numpy_fc_model = NumpyModule(torch_fc_model)

    # Predict with the floating point model
    numpy_prediction = numpy_fc_model(numpy_input)

    # Quantize with the post-training static method
    post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_fc_model)
    quantized_model = post_training_quant.quantize_module(numpy_input)

    # Quantize the input
    q_input = QuantizedArray(n_bits, numpy_input)

    # Get the quantized prediction
    q_prediction = quantized_model(q_input)

    # Dequantize to get back to real values
    dequant_prediction = q_prediction.dequant()

    assert numpy.isclose(numpy_prediction, dequant_prediction, atol=atol).all()