feat: update compile_torch_model to return compiled quantized module
closes #898
@@ -3,8 +3,7 @@
 import numpy
 from torch import nn
 
-from concrete.torch import NumpyModule
-
+from ..torch import NumpyModule
 from .quantized_activations import QuantizedSigmoid
 from .quantized_array import QuantizedArray
 from .quantized_layers import QuantizedLinear
@@ -4,12 +4,10 @@ from typing import Optional, Union
 
 import numpy
 
-from concrete.common.compilation.artifacts import CompilationArtifacts
-from concrete.common.compilation.configuration import CompilationConfiguration
-from concrete.common.fhe_circuit import FHECircuit
-
-from ..numpy import EncryptedTensor, UnsignedInteger
-from ..numpy.compile import compile_numpy_function
+from ..common.compilation.artifacts import CompilationArtifacts
+from ..common.compilation.configuration import CompilationConfiguration
+from ..common.fhe_circuit import FHECircuit
+from ..numpy.np_fhe_compiler import NPFHECompiler
 from .quantized_array import QuantizedArray
 
@@ -95,6 +93,7 @@ class QuantizedModule:
         q_input: QuantizedArray,
         compilation_configuration: Optional[CompilationConfiguration] = None,
         compilation_artifacts: Optional[CompilationArtifacts] = None,
+        show_mlir: bool = False,
     ) -> FHECircuit:
         """Compile the forward function of the module.
@@ -105,20 +104,25 @@ class QuantizedModule:
                 compilation
             compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill during
                 compilation
+            show_mlir (bool, optional): if set, the MLIR produced by the converter and which is
+                going to be sent to the compiler backend is shown on the screen, e.g., for debugging
+                or demo. Defaults to False.
 
         Returns:
-            bool: Success flag from the compilation.
+            FHECircuit: the compiled FHECircuit.
         """
 
         self.q_input = copy.deepcopy(q_input)
-        self.forward_fhe = compile_numpy_function(
+        compiler = NPFHECompiler(
            self.forward,
            {
-                "q_x": EncryptedTensor(
-                    UnsignedInteger(self.q_input.n_bits), shape=(1, *self.q_input.qvalues.shape[1:])
-                )
+                "q_x": "encrypted",
            },
-            [numpy.expand_dims(arr, 0) for arr in self.q_input.qvalues],  # Super weird formatting
-            compilation_configuration=compilation_configuration,
-            compilation_artifacts=compilation_artifacts,
+            compilation_configuration,
+            compilation_artifacts,
        )
+        compiler.eval_on_inputset((numpy.expand_dims(arr, 0) for arr in self.q_input.qvalues))
+
+        self.forward_fhe = compiler.get_compiled_fhe_circuit(show_mlir)
+
+        return self.forward_fhe
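For context, the NPFHECompiler flow that the rewritten compile() adopts can be sketched in isolation. This is a minimal sketch, not code from the commit: the toy function f and its inputset are made up, and it assumes the configuration and artifacts arguments can be omitted, as their Optional defaults in compile() suggest.

import numpy

from concrete.numpy.np_fhe_compiler import NPFHECompiler


def f(q_x):
    # Toy integer function standing in for QuantizedModule.forward.
    return q_x + 42


# Declare the single parameter as encrypted, as the diff does for "q_x".
compiler = NPFHECompiler(f, {"q_x": "encrypted"})

# Evaluating on a representative inputset lets the compiler observe value
# bounds before the circuit is produced.
compiler.eval_on_inputset(numpy.arange(10))

fhe_circuit = compiler.get_compiled_fhe_circuit()

compile() wraps exactly this sequence, stores the circuit in self.forward_fhe, and now returns it instead of a success flag.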
@@ -1,30 +1,55 @@
 """torch compilation function."""
 
-from typing import Optional
+from typing import Iterable, Optional, Tuple, Union
 
 import numpy
 import torch
 
 from ..common.compilation import CompilationArtifacts, CompilationConfiguration
-from ..quantization import PostTrainingAffineQuantization, QuantizedArray
+from ..quantization import PostTrainingAffineQuantization, QuantizedArray, QuantizedModule
 from . import NumpyModule
 
+TorchDataset = Union[Iterable[torch.Tensor], Iterable[Tuple[torch.Tensor, ...]]]
+NPDataset = Union[Iterable[numpy.ndarray], Iterable[Tuple[numpy.ndarray, ...]]]
+
+
+def convert_torch_tensor_or_numpy_array_to_numpy_array(
+    torch_tensor_or_numpy_array: Union[torch.Tensor, numpy.ndarray]
+) -> numpy.ndarray:
+    """Convert a torch tensor or a numpy array to a numpy array.
+
+    Args:
+        torch_tensor_or_numpy_array (Union[torch.Tensor, numpy.ndarray]): the value that is either
+            a torch tensor or a numpy array.
+
+    Returns:
+        numpy.ndarray: the value converted to a numpy array.
+    """
+    return (
+        torch_tensor_or_numpy_array
+        if isinstance(torch_tensor_or_numpy_array, numpy.ndarray)
+        else torch_tensor_or_numpy_array.cpu().numpy()
+    )
+
+
 def compile_torch_model(
     torch_model: torch.nn.Module,
-    torch_inputset: torch.FloatTensor,
+    torch_inputset: Union[TorchDataset, NPDataset],
     compilation_configuration: Optional[CompilationConfiguration] = None,
     compilation_artifacts: Optional[CompilationArtifacts] = None,
     show_mlir: bool = False,
     n_bits=7,
-):
+) -> QuantizedModule:
     """Take a model in torch, turn it to numpy, transform weights to integer.
 
     Later, we'll compile the integer model.
 
     Args:
         torch_model (torch.nn.Module): the model to quantize,
-        torch_inputset (torch.FloatTensor): the inputset, in torch form
+        torch_inputset (Union[TorchDataset, NPDataset]): the inputset, can contain either torch
+            tensors or numpy.ndarray or tuples of those for networks requiring multiple inputs
-        function_parameters_encrypted_status (Dict[str, Union[str, EncryptedStatus]]): a dict with
-            the name of the parameter and its encrypted status
         compilation_configuration (CompilationConfiguration): Configuration object to use
             during compilation
         compilation_artifacts (CompilationArtifacts): Artifacts object to fill
@@ -33,35 +58,37 @@ def compile_torch_model(
             to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
         n_bits: the number of bits for the quantization
 
+    Returns:
+        QuantizedModule: The resulting compiled QuantizedModule.
     """
 
     # Create corresponding numpy model
     numpy_model = NumpyModule(torch_model)
 
     # Torch input to numpy
-    numpy_inputset = numpy.array(
-        [
-            tuple(val.cpu().numpy() for val in input_)
-            if isinstance(input_, tuple)
-            else tuple(input_.cpu().numpy())
-            for input_ in torch_inputset
-        ]
+    numpy_inputset = (
+        tuple(convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in input_)
+        if isinstance(input_, tuple)
+        else convert_torch_tensor_or_numpy_array_to_numpy_array(input_)
+        for input_ in torch_inputset
     )
+
+    numpy_inputset_as_single_array = numpy.concatenate(
+        tuple(numpy.expand_dims(arr, 0) for arr in numpy_inputset)
+    )
 
     # Quantize with post-training static method, to have a model with integer weights
     post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_model)
-    quantized_model = post_training_quant.quantize_module(numpy_inputset)
-    model_to_compile = quantized_model
+    quantized_module = post_training_quant.quantize_module(numpy_inputset_as_single_array)
 
     # Quantize input
-    quantized_numpy_inputset = QuantizedArray(n_bits, numpy_inputset)
+    quantized_numpy_inputset = QuantizedArray(n_bits, numpy_inputset_as_single_array)
 
-    # FIXME: just print, to avoid to have useless vars. Will be removed once we can finally compile
-    # the model
-    print(
-        model_to_compile,
+    quantized_module.compile(
         quantized_numpy_inputset,
         compilation_configuration,
         compilation_artifacts,
         show_mlir,
     )
+
+    return quantized_module
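A small illustration of the two inputset shapes the widened signature accepts, and of the new conversion helper. The tensors and sizes below are made up, and it is assumed the helper is importable from concrete.torch.compile, where the diff defines it.

import numpy
import torch

from concrete.torch.compile import convert_torch_tensor_or_numpy_array_to_numpy_array

# Single-input networks: an iterable of tensors or arrays (TorchDataset / NPDataset).
single_input_set = [torch.randn(3) for _ in range(10)]

# Multi-input networks: an iterable of tuples, one element per model input;
# torch tensors and numpy arrays can be mixed.
multi_input_set = [
    (torch.randn(3), numpy.random.uniform(-1, 1, size=3)) for _ in range(10)
]

# The helper normalizes both kinds of element to numpy.ndarray.
assert all(
    isinstance(convert_torch_tensor_or_numpy_array_to_numpy_array(x), numpy.ndarray)
    for x in single_input_set
)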
@@ -85,7 +85,7 @@ def test_quantized_module_compilation(
     homomorphic_predictions = homomorphic_predictions.reshape(dequant_predictions.shape)
 
     # Make sure homomorphic_predictions are the same as dequant_predictions
-    if numpy.isclose(homomorphic_predictions.ravel(), dequant_predictions.ravel()).all():
+    if numpy.isclose(homomorphic_predictions, dequant_predictions).all():
         return
 
     # Bad computation after nb_tries
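The ravel() calls could be dropped because numpy.isclose already compares same-shaped arrays elementwise; flattening changes nothing once homomorphic_predictions has been reshaped to dequant_predictions.shape. A quick check with illustrative values:

import numpy

a = numpy.array([[0.1, 0.2], [0.3, 0.4]])
b = a.copy()

# Elementwise comparison on the 2-D arrays...
assert numpy.isclose(a, b).all()
# ...is equivalent to comparing the flattened views.
assert numpy.isclose(a.ravel(), b.ravel()).all()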
@@ -1,62 +1,79 @@
 """Tests for the torch to numpy module."""
+import numpy
 import pytest
 import torch
 from torch import nn
 
+from concrete.quantization import QuantizedArray
 from concrete.torch.compile import compile_torch_model
 
+# INPUT_OUTPUT_FEATURE is the number of input and output of each of the network layers.
+# (as well as the input of the network itself)
+INPUT_OUTPUT_FEATURE = [1, 2, 3]
+
+
 class FC(nn.Module):
     """Torch model for the tests"""
 
-    def __init__(self):
+    def __init__(self, input_output):
         super().__init__()
-        self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
+        self.fc1 = nn.Linear(in_features=input_output, out_features=input_output)
         self.sigmoid1 = nn.Sigmoid()
-        self.fc2 = nn.Linear(in_features=128, out_features=64)
-        self.sigmoid2 = nn.Sigmoid()
-        self.fc3 = nn.Linear(in_features=64, out_features=64)
-        self.sigmoid3 = nn.Sigmoid()
-        self.fc4 = nn.Linear(in_features=64, out_features=64)
-        self.sigmoid4 = nn.Sigmoid()
-        self.fc5 = nn.Linear(in_features=64, out_features=10)
+        self.fc2 = nn.Linear(in_features=input_output, out_features=input_output)
 
     def forward(self, x):
         """Forward pass."""
         out = self.fc1(x)
         out = self.sigmoid1(out)
         out = self.fc2(out)
-        out = self.sigmoid2(out)
-        out = self.fc3(out)
-        out = self.sigmoid3(out)
-        out = self.fc4(out)
-        out = self.sigmoid4(out)
-        out = self.fc5(out)
 
         return out
 
 
 @pytest.mark.parametrize(
-    "model, input_shape",
-    [
-        pytest.param(FC, (100, 32 * 32 * 3)),
-    ],
+    "model",
+    [pytest.param(FC)],
 )
-def test_compile_torch(model, input_shape, default_compilation_configuration, seed_torch):
+@pytest.mark.parametrize(
+    "input_output_feature",
+    [pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
+)
+def test_compile_torch(input_output_feature, model, seed_torch, default_compilation_configuration):
     """Test the different model architecture from torch numpy."""
 
     # Seed torch
     seed_torch()
 
-    # Define the torch model
-    torch_fc_model = model()
     n_bits = 2
+
+    # Define an input shape (n_examples, n_features)
+    n_examples = 10
+
+    # Define the torch model
+    torch_fc_model = model(input_output_feature)
     # Create random input
-    torch_inputset = torch.randn(input_shape)
+    inputset = [numpy.random.uniform(-1, 1, size=input_output_feature) for _ in range(n_examples)]
 
     # Compile
-    compile_torch_model(
+    quantized_numpy_module = compile_torch_model(
         torch_fc_model,
-        torch_inputset,
+        inputset,
         default_compilation_configuration,
         n_bits=n_bits,
     )
+
+    # Compare predictions between FHE and QuantizedModule
+    clear_predictions = []
+    homomorphic_predictions = []
+    for numpy_input in inputset:
+        q_input = QuantizedArray(n_bits, numpy_input)
+        x_q = q_input.qvalues
+        clear_predictions.append(quantized_numpy_module.forward(x_q))
+        homomorphic_predictions.append(
+            quantized_numpy_module.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))
+        )
+
+    clear_predictions = numpy.array(clear_predictions)
+    homomorphic_predictions = numpy.array(homomorphic_predictions)
+
+    # Make sure homomorphic_predictions are the same as dequant_predictions
+    assert numpy.array_equal(homomorphic_predictions, clear_predictions)
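Putting the commit together, the new workflow outside the test harness can be sketched as follows. The two-neuron model, bit width, and dataset are illustrative; the compilation configuration is left at its defaults, and the forward_fhe.run input format follows the test above.

import numpy
import torch
from torch import nn

from concrete.quantization import QuantizedArray
from concrete.torch.compile import compile_torch_model


class TinyFC(nn.Module):
    """Minimal stand-in for the FC test model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=2, out_features=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc(x))


n_bits = 2
inputset = [numpy.random.uniform(-1, 1, size=2) for _ in range(10)]

# compile_torch_model now returns the compiled QuantizedModule directly.
quantized_module = compile_torch_model(TinyFC(), inputset, n_bits=n_bits)

# Clear quantized inference and FHE inference on the same quantized input.
x_q = QuantizedArray(n_bits, inputset[0]).qvalues
clear_out = quantized_module.forward(x_q)
fhe_out = quantized_module.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))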
|
||||
Reference in New Issue
Block a user