Mirror of https://github.com/zama-ai/concrete.git, synced 2026-02-09 03:55:04 -05:00
feat: static post training quantization and quantization module
concrete/quantization/__init__.py
@@ -1,4 +1,6 @@
"""Modules for quantization."""
from .post_training import PostTrainingAffineQuantization
from .quantized_activations import QuantizedSigmoid
from .quantized_array import QuantizedArray
from .quantized_layers import QuantizedLinear
from .quantized_module import QuantizedModule
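
The __init__.py hunk above defines the public import surface of the new quantization package. A minimal sketch of what downstream code can import (the tests further down use the same entry points):

from concrete.quantization import (
    PostTrainingAffineQuantization,
    QuantizedArray,
    QuantizedModule,
)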
106  concrete/quantization/post_training.py  Normal file
@@ -0,0 +1,106 @@
"""Post Training Quantization methods."""

import numpy
from torch import nn

from concrete.torch import NumpyModule

from .quantized_activations import QuantizedSigmoid
from .quantized_array import QuantizedArray
from .quantized_layers import QuantizedLinear
from .quantized_module import QuantizedModule


class PostTrainingAffineQuantization:
    """Post-training Affine Quantization."""

    IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid}

    quant_layers_dict: dict
    n_bits: int
    quant_params: dict
    numpy_model: NumpyModule

    def __init__(self, n_bits: int, numpy_model: NumpyModule):
        """Create the quantized version of a numpy module.

        Args:
            n_bits (int): Number of bits to quantize the model. Currently,
                this n_bits is used for all activations, inputs and weights.
            numpy_model (NumpyModule): Model in numpy.
        """
        self.quant_layers_dict = {}
        self.n_bits = n_bits
        self.quant_params = {}
        self.numpy_model = numpy_model

    def quantize_module(self, calibration_data: numpy.ndarray) -> QuantizedModule:
        """Quantize numpy module.

        Following https://arxiv.org/abs/1712.05877 guidelines.

        Args:
            calibration_data (numpy.ndarray): Data that will be used to compute the
                bounds, scales and zero point values for every quantized object.

        Returns:
            QuantizedModule: Quantized numpy module.
        """
        # First, transform all parameters to their quantized version
        self._quantize_params()
        # Quantize and calibrate each output layer/activation
        self._quantize_layers(calibration_data=calibration_data)
        # Create a quantized module from self.quant_layers_dict
        return QuantizedModule(self.quant_layers_dict)

    def _quantize_params(self):
        """Transform all floating-point parameters to integers."""
        for name, params in self.numpy_model.numpy_module_dict.items():
            self.quant_params[name] = QuantizedArray(self.n_bits, params)

    def _calibrate_layers_activation(self, name, q_function, calibration_data):
        # Calibrate the output of the layer
        q_function.calibrate(calibration_data)
        # Store the calibrated quantized layer
        self.quant_layers_dict[name] = q_function
        # Create new calibration data (output of the previous layer)
        q_calibration_data = QuantizedArray(self.n_bits, calibration_data)
        # Dequantize to get the values in the clear, ready for the next calibration
        return q_function(q_calibration_data).dequant()

    def _quantize_layers(self, calibration_data: numpy.ndarray):
        """Compute all parameters for the static post-training quantization.

        Does a forward pass over a batch of data and computes all
        quantization parameters for activations and layers.
        """
        for name, layer in self.numpy_model.torch_model.named_children():
            if isinstance(layer, nn.Linear):
                # Create a QuantizedLinear layer
                q_weights = self.quant_params[f"{name}.weight"]
                q_bias = self.quant_params[f"{name}.bias"]
                q_layer = QuantizedLinear(self.n_bits, q_weights, q_bias)
                # Calibrate and get new calibration_data for the next layer/activation
                calibration_data = self._calibrate_layers_activation(
                    name, q_layer, calibration_data
                )
            elif isinstance(layer, nn.Sigmoid):
                # Create a new quantized layer (based on type(layer))
                q_sigmoid = QuantizedSigmoid(n_bits=self.n_bits)
                calibration_data = self._calibrate_layers_activation(
                    name, q_sigmoid, calibration_data
                )
            else:  # pragma: no cover
                # Raise an error if we hit a layer that has not been implemented
                hf_m_names = sorted(module.__name__ for module in self.IMPLEMENTED_MODULES)
                raise ValueError(
                    f"The following module is currently not implemented: "
                    f"{type(layer).__name__}. Please stick to the available "
                    f"quantized modules: {', '.join(hf_m_names)}."
                )
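
quantize_module follows the affine scheme of https://arxiv.org/abs/1712.05877, where a real value r is represented as r = scale * (q - zero_point). A standalone sketch of that mapping (an illustration of the scheme, not the QuantizedArray implementation; the min/max calibration and rounding details here are assumptions, and v_max > v_min is assumed):

import numpy

def affine_quantize(values: numpy.ndarray, n_bits: int):
    """Quantize floats to unsigned n_bits integers (illustrative only)."""
    v_min, v_max = values.min(), values.max()  # bounds from calibration data
    scale = (v_max - v_min) / (2 ** n_bits - 1)  # float step per integer step
    zero_point = int(numpy.round(-v_min / scale))  # integer representing 0.0
    q = numpy.clip(numpy.round(values / scale) + zero_point, 0, 2 ** n_bits - 1)
    return q.astype(numpy.int64), scale, zero_point

def affine_dequantize(q, scale, zero_point):
    """Approximate inverse: r = scale * (q - zero_point)."""
    return scale * (q - zero_point)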
28  concrete/quantization/quantized_module.py  Normal file
@@ -0,0 +1,28 @@
"""QuantizedModule API."""
import copy

from .quantized_array import QuantizedArray


class QuantizedModule:
    """Inference for a quantized model."""

    def __init__(self, quant_layers_dict: dict):
        self.quant_layers_dict = copy.deepcopy(quant_layers_dict)

    def __call__(self, x: QuantizedArray) -> QuantizedArray:
        return self.forward(x)

    def forward(self, q_x: QuantizedArray) -> QuantizedArray:
        """Forward pass with numpy functions only.

        Args:
            q_x (QuantizedArray): QuantizedArray containing the inputs.

        Returns:
            (QuantizedArray): Prediction of the quantized model.
        """
        for layer in self.quant_layers_dict.values():
            q_x = layer(q_x)

        return q_x
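
One note on forward: it iterates over quant_layers_dict in insertion order, which matches the order of named_children() used during quantization (Python dicts preserve insertion order since 3.7). End to end, the intended flow mirrors the test at the bottom of this commit; a sketch, where torch_model stands for any assumed nn.Module built only from the implemented layers (nn.Linear and nn.Sigmoid):

import numpy
from concrete.quantization import PostTrainingAffineQuantization, QuantizedArray
from concrete.torch import NumpyModule

numpy_model = NumpyModule(torch_model)  # torch_model: assumed, not defined here
calibration_data = numpy.random.uniform(size=(100, 32 * 32 * 3))
# Calibrate every layer/activation, then build the quantized module
quantized_module = PostTrainingAffineQuantization(
    8, numpy_model
).quantize_module(calibration_data)
# Inference is quantize -> forward -> dequantize
prediction = quantized_module(QuantizedArray(8, calibration_data)).dequant()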
@@ -39,7 +39,6 @@ class NumpyModule:
    def convert_to_numpy(self):
        """Transform all parameters from torch tensors to numpy arrays."""
        self.numpy_module_dict = {}
        self.numpy_module_quant_dict = {}

        for name, weights in self.torch_model.state_dict().items():
            params = weights.detach().numpy()
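
This hunk is why the lookups in _quantize_layers work: numpy_module_dict is keyed by state_dict() names, which follow the "<child>.<param>" convention for named children, so f"{name}.weight" and f"{name}.bias" line up with named_children(). A quick standalone check of that PyTorch convention:

from torch import nn

model = nn.Sequential()
model.add_module("fc1", nn.Linear(4, 2))
# state_dict keys combine the child name and the parameter name
print(list(model.state_dict().keys()))  # ['fc1.weight', 'fc1.bias']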
107  tests/quantization/test_quantized_module.py  Normal file
@@ -0,0 +1,107 @@
"""Tests for the quantized module."""
import numpy
import pytest
import torch
from torch import nn

from concrete.quantization import PostTrainingAffineQuantization, QuantizedArray
from concrete.torch import NumpyModule


class CNN(nn.Module):
    """Torch CNN model for the tests."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.AvgPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Forward pass."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class FC(nn.Module):
    """Torch model for the tests."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.sigmoid2 = nn.Sigmoid()
        self.fc3 = nn.Linear(in_features=64, out_features=64)
        self.sigmoid3 = nn.Sigmoid()
        self.fc4 = nn.Linear(in_features=64, out_features=64)
        self.sigmoid4 = nn.Sigmoid()
        self.fc5 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Forward pass."""
        out = self.fc1(x)
        out = self.sigmoid1(out)
        out = self.fc2(out)
        out = self.sigmoid2(out)
        out = self.fc3(out)
        out = self.sigmoid3(out)
        out = self.fc4(out)
        out = self.sigmoid4(out)
        out = self.fc5(out)

        return out


N_BITS_ATOL_TUPLE_LIST = [
    (28, 10 ** -2),
    (20, 10 ** -2),
    (16, 10 ** -1),
    (8, 10 ** -0),
    (4, 10 ** -0),
]


@pytest.mark.parametrize(
    "n_bits, atol",
    [pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize(
    "model, input_shape",
    [
        pytest.param(FC, (100, 32 * 32 * 3)),
    ],
)
def test_quantized_linear(model, input_shape, n_bits, atol):
    """Test the quantized module with post-training static quantization.

    With n_bits >> 0, we expect the results of the quantized module
    to match those of the standard module.
    """
    # Define the torch model
    torch_fc_model = model()
    # Create a random input
    numpy_input = numpy.random.uniform(size=input_shape)
    # Create the corresponding numpy model
    numpy_fc_model = NumpyModule(torch_fc_model)
    # Predict with the real model
    numpy_prediction = numpy_fc_model(numpy_input)
    # Quantize with the post-training static method
    post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_fc_model)
    quantized_model = post_training_quant.quantize_module(numpy_input)
    # Quantize the input
    q_input = QuantizedArray(n_bits, numpy_input)
    # Get the quantized prediction
    q_prediction = quantized_model(q_input)
    # Dequantize to get back to real values
    dequant_prediction = q_prediction.dequant()

    assert numpy.isclose(numpy_prediction, dequant_prediction, atol=atol).all()
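
The atol values track the quantization step size: for inputs spanning a range of width R, an n_bits affine quantizer has a per-value step of about R / (2**n_bits - 1), and the error compounds across the five linear layers and four sigmoids of FC, hence the looser tolerances at low n_bits. A back-of-the-envelope check, assuming a unit input range:

# Per-value quantization step for a unit-range signal at each tested n_bits
for n_bits in (28, 20, 16, 8, 4):
    print(f"n_bits={n_bits:2d} -> step ~ {1.0 / (2 ** n_bits - 1):.1e}")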