From c7255dfd6653aee37bea79ed00d1ed8448656e0b Mon Sep 17 00:00:00 2001
From: Arthur Meyre
Date: Thu, 2 Dec 2021 16:17:52 +0100
Subject: [PATCH] feat: update compile_torch_model to return compiled
 quantized module

closes #898
---
 concrete/quantization/post_training.py    |  3 +-
 concrete/quantization/quantized_module.py | 32 ++++++-----
 concrete/torch/compile.py                 | 65 ++++++++++++++-------
 tests/quantization/test_compilation.py    |  2 +-
 tests/torch/test_compile_torch.py         | 69 ++++++++++++++---------
 5 files changed, 109 insertions(+), 62 deletions(-)

diff --git a/concrete/quantization/post_training.py b/concrete/quantization/post_training.py
index 972af4df9..f60fb8614 100644
--- a/concrete/quantization/post_training.py
+++ b/concrete/quantization/post_training.py
@@ -3,8 +3,7 @@
 import numpy
 from torch import nn
 
-from concrete.torch import NumpyModule
-
+from ..torch import NumpyModule
 from .quantized_activations import QuantizedSigmoid
 from .quantized_array import QuantizedArray
 from .quantized_layers import QuantizedLinear
diff --git a/concrete/quantization/quantized_module.py b/concrete/quantization/quantized_module.py
index c310ac8de..a6210034f 100644
--- a/concrete/quantization/quantized_module.py
+++ b/concrete/quantization/quantized_module.py
@@ -4,12 +4,10 @@
 from typing import Optional, Union
 
 import numpy
 
-from concrete.common.compilation.artifacts import CompilationArtifacts
-from concrete.common.compilation.configuration import CompilationConfiguration
-from concrete.common.fhe_circuit import FHECircuit
-
-from ..numpy import EncryptedTensor, UnsignedInteger
-from ..numpy.compile import compile_numpy_function
+from ..common.compilation.artifacts import CompilationArtifacts
+from ..common.compilation.configuration import CompilationConfiguration
+from ..common.fhe_circuit import FHECircuit
+from ..numpy.np_fhe_compiler import NPFHECompiler
 from .quantized_array import QuantizedArray
 
@@ -95,6 +93,7 @@ class QuantizedModule:
         q_input: QuantizedArray,
         compilation_configuration: Optional[CompilationConfiguration] = None,
         compilation_artifacts: Optional[CompilationArtifacts] = None,
+        show_mlir: bool = False,
     ) -> FHECircuit:
         """Compile the forward function of the module.
 
@@ -105,20 +104,25 @@ class QuantizedModule:
                 compilation
             compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill during
                 compilation
+            show_mlir (bool, optional): if set, the MLIR produced by the converter and which is
+                going to be sent to the compiler backend is shown on the screen, e.g., for debugging
+                or demo. Defaults to False.
+
         Returns:
-            bool: Success flag from the compilation.
+            FHECircuit: the compiled FHECircuit.
""" self.q_input = copy.deepcopy(q_input) - self.forward_fhe = compile_numpy_function( + compiler = NPFHECompiler( self.forward, { - "q_x": EncryptedTensor( - UnsignedInteger(self.q_input.n_bits), shape=(1, *self.q_input.qvalues.shape[1:]) - ) + "q_x": "encrypted", }, - [numpy.expand_dims(arr, 0) for arr in self.q_input.qvalues], # Super weird formatting - compilation_configuration=compilation_configuration, - compilation_artifacts=compilation_artifacts, + compilation_configuration, + compilation_artifacts, ) + compiler.eval_on_inputset((numpy.expand_dims(arr, 0) for arr in self.q_input.qvalues)) + + self.forward_fhe = compiler.get_compiled_fhe_circuit(show_mlir) + return self.forward_fhe diff --git a/concrete/torch/compile.py b/concrete/torch/compile.py index ffbaa11fb..a9883f369 100644 --- a/concrete/torch/compile.py +++ b/concrete/torch/compile.py @@ -1,30 +1,55 @@ """torch compilation function.""" -from typing import Optional +from typing import Iterable, Optional, Tuple, Union import numpy import torch from ..common.compilation import CompilationArtifacts, CompilationConfiguration -from ..quantization import PostTrainingAffineQuantization, QuantizedArray +from ..quantization import PostTrainingAffineQuantization, QuantizedArray, QuantizedModule from . import NumpyModule +TorchDataset = Union[Iterable[torch.Tensor], Iterable[Tuple[torch.Tensor, ...]]] +NPDataset = Union[Iterable[numpy.ndarray], Iterable[Tuple[numpy.ndarray, ...]]] + + +def convert_torch_tensor_or_numpy_array_to_numpy_array( + torch_tensor_or_numpy_array: Union[torch.Tensor, numpy.ndarray] +) -> numpy.ndarray: + """Convert a torch tensor or a numpy array to a numpy array. + + Args: + torch_tensor_or_numpy_array (Union[torch.Tensor, numpy.ndarray]): the value that is either + a torch tensor or a numpy array. + + Returns: + numpy.ndarray: the value converted to a numpy array. + """ + return ( + torch_tensor_or_numpy_array + if isinstance(torch_tensor_or_numpy_array, numpy.ndarray) + else torch_tensor_or_numpy_array.cpu().numpy() + ) + def compile_torch_model( torch_model: torch.nn.Module, - torch_inputset: torch.FloatTensor, + torch_inputset: Union[TorchDataset, NPDataset], compilation_configuration: Optional[CompilationConfiguration] = None, compilation_artifacts: Optional[CompilationArtifacts] = None, show_mlir: bool = False, n_bits=7, -): +) -> QuantizedModule: """Take a model in torch, turn it to numpy, transform weights to integer. Later, we'll compile the integer model. Args: torch_model (torch.nn.Module): the model to quantize, - torch_inputset (torch.FloatTensor): the inputset, in torch form + torch_inputset (Union[TorchDataset, NPDataset]): the inputset, can contain either torch + tensors or numpy.ndarray or tuples of those for networks requiring multiple inputs + function_parameters_encrypted_status (Dict[str, Union[str, EncryptedStatus]]): a dict with + the name of the parameter and its encrypted status compilation_configuration (CompilationConfiguration): Configuration object to use during compilation compilation_artifacts (CompilationArtifacts): Artifacts object to fill @@ -33,35 +58,37 @@ def compile_torch_model( to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo n_bits: the number of bits for the quantization + Returns: + QuantizedModule: The resulting compiled QuantizedModule. 
""" # Create corresponding numpy model numpy_model = NumpyModule(torch_model) # Torch input to numpy - numpy_inputset = numpy.array( - [ - tuple(val.cpu().numpy() for val in input_) - if isinstance(input_, tuple) - else tuple(input_.cpu().numpy()) - for input_ in torch_inputset - ] + numpy_inputset = ( + tuple(convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in input_) + if isinstance(input_, tuple) + else convert_torch_tensor_or_numpy_array_to_numpy_array(input_) + for input_ in torch_inputset + ) + + numpy_inputset_as_single_array = numpy.concatenate( + tuple(numpy.expand_dims(arr, 0) for arr in numpy_inputset) ) # Quantize with post-training static method, to have a model with integer weights post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_model) - quantized_model = post_training_quant.quantize_module(numpy_inputset) - model_to_compile = quantized_model + quantized_module = post_training_quant.quantize_module(numpy_inputset_as_single_array) # Quantize input - quantized_numpy_inputset = QuantizedArray(n_bits, numpy_inputset) + quantized_numpy_inputset = QuantizedArray(n_bits, numpy_inputset_as_single_array) - # FIXME: just print, to avoid to have useless vars. Will be removed once we can finally compile - # the model - print( - model_to_compile, + quantized_module.compile( quantized_numpy_inputset, compilation_configuration, compilation_artifacts, show_mlir, ) + + return quantized_module diff --git a/tests/quantization/test_compilation.py b/tests/quantization/test_compilation.py index d5bf51bd8..6394d425a 100644 --- a/tests/quantization/test_compilation.py +++ b/tests/quantization/test_compilation.py @@ -85,7 +85,7 @@ def test_quantized_module_compilation( homomorphic_predictions = homomorphic_predictions.reshape(dequant_predictions.shape) # Make sure homomorphic_predictions are the same as dequant_predictions - if numpy.isclose(homomorphic_predictions.ravel(), dequant_predictions.ravel()).all(): + if numpy.isclose(homomorphic_predictions, dequant_predictions).all(): return # Bad computation after nb_tries diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index 00fc3c8f2..175e36ab7 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -1,62 +1,79 @@ """Tests for the torch to numpy module.""" +import numpy import pytest -import torch from torch import nn +from concrete.quantization import QuantizedArray from concrete.torch.compile import compile_torch_model +# INPUT_OUTPUT_FEATURE is the number of input and output of each of the network layers. 
+# (as well as the input of the network itself)
+INPUT_OUTPUT_FEATURE = [1, 2, 3]
+
 
 class FC(nn.Module):
     """Torch model for the tests"""
 
-    def __init__(self):
+    def __init__(self, input_output):
         super().__init__()
-        self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
+        self.fc1 = nn.Linear(in_features=input_output, out_features=input_output)
         self.sigmoid1 = nn.Sigmoid()
-        self.fc2 = nn.Linear(in_features=128, out_features=64)
-        self.sigmoid2 = nn.Sigmoid()
-        self.fc3 = nn.Linear(in_features=64, out_features=64)
-        self.sigmoid3 = nn.Sigmoid()
-        self.fc4 = nn.Linear(in_features=64, out_features=64)
-        self.sigmoid4 = nn.Sigmoid()
-        self.fc5 = nn.Linear(in_features=64, out_features=10)
+        self.fc2 = nn.Linear(in_features=input_output, out_features=input_output)
 
     def forward(self, x):
         """Forward pass."""
         out = self.fc1(x)
         out = self.sigmoid1(out)
         out = self.fc2(out)
-        out = self.sigmoid2(out)
-        out = self.fc3(out)
-        out = self.sigmoid3(out)
-        out = self.fc4(out)
-        out = self.sigmoid4(out)
-        out = self.fc5(out)
 
         return out
 
 
 @pytest.mark.parametrize(
-    "model, input_shape",
-    [
-        pytest.param(FC, (100, 32 * 32 * 3)),
-    ],
+    "model",
+    [pytest.param(FC)],
 )
-def test_compile_torch(model, input_shape, default_compilation_configuration, seed_torch):
+@pytest.mark.parametrize(
+    "input_output_feature",
+    [pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
+)
+def test_compile_torch(input_output_feature, model, seed_torch, default_compilation_configuration):
     """Test the different model architecture from torch numpy."""
 
     # Seed torch
     seed_torch()
 
-    # Define the torch model
-    torch_fc_model = model()
+    n_bits = 2
+    # Define an input shape (n_examples, n_features)
+    n_examples = 10
+
+    # Define the torch model
+    torch_fc_model = model(input_output_feature)
 
     # Create random input
-    torch_inputset = torch.randn(input_shape)
+    inputset = [numpy.random.uniform(-1, 1, size=input_output_feature) for _ in range(n_examples)]
 
     # Compile
-    compile_torch_model(
+    quantized_numpy_module = compile_torch_model(
         torch_fc_model,
-        torch_inputset,
+        inputset,
         default_compilation_configuration,
+        n_bits=n_bits,
     )
+
+    # Compare predictions between FHE and QuantizedModule
+    clear_predictions = []
+    homomorphic_predictions = []
+    for numpy_input in inputset:
+        q_input = QuantizedArray(n_bits, numpy_input)
+        x_q = q_input.qvalues
+        clear_predictions.append(quantized_numpy_module.forward(x_q))
+        homomorphic_predictions.append(
+            quantized_numpy_module.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))
+        )
+
+    clear_predictions = numpy.array(clear_predictions)
+    homomorphic_predictions = numpy.array(homomorphic_predictions)
+
+    # Make sure homomorphic_predictions are the same as clear_predictions
+    assert numpy.array_equal(homomorphic_predictions, clear_predictions)
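
Usage sketch (not part of the diff above): with this change, compile_torch_model quantizes the torch
model, compiles the quantized forward function and returns the resulting QuantizedModule instead of
discarding it, so a caller can compare the clear quantized forward pass with the FHE execution, as
the updated test does. The snippet below is a minimal illustration, assuming that leaving
compilation_configuration at its default None is acceptable in your setup (the repo tests pass a
default_compilation_configuration fixture instead); TinyNet is an illustrative stand-in for the FC
test model, not code introduced by this patch.

    import numpy
    from torch import nn

    from concrete.quantization import QuantizedArray
    from concrete.torch.compile import compile_torch_model


    class TinyNet(nn.Module):
        """Small sigmoid MLP mirroring the FC model used in the tests."""

        def __init__(self, n_features):
            super().__init__()
            self.fc1 = nn.Linear(n_features, n_features)
            self.sigmoid1 = nn.Sigmoid()
            self.fc2 = nn.Linear(n_features, n_features)

        def forward(self, x):
            """Forward pass."""
            return self.fc2(self.sigmoid1(self.fc1(x)))


    n_bits = 2
    n_features = 2

    # Inputset made of numpy arrays, as now accepted by compile_torch_model
    inputset = [numpy.random.uniform(-1, 1, size=n_features) for _ in range(10)]

    # compile_torch_model now returns the compiled QuantizedModule
    quantized_module = compile_torch_model(TinyNet(n_features), inputset, n_bits=n_bits)

    # Quantize one input, then compare the clear quantized prediction with the
    # FHE execution of the compiled circuit (same pattern as the updated test)
    q_input = QuantizedArray(n_bits, inputset[0])
    clear_prediction = quantized_module.forward(q_input.qvalues)
    fhe_prediction = quantized_module.forward_fhe.run(
        numpy.array([q_input.qvalues]).astype(numpy.uint8)
    )
    assert numpy.array_equal(fhe_prediction, clear_prediction)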