feat: adding a compile_torch_model function

refs #898
This commit is contained in:
Benoit Chevallier-Mames
2021-11-23 18:40:08 +01:00
committed by Benoit Chevallier
parent e111891a41
commit edefde9189
2 changed files with 129 additions and 0 deletions

67
concrete/torch/compile.py Normal file
View File

@@ -0,0 +1,67 @@
"""torch compilation function."""
from typing import Optional
import numpy
import torch
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..quantization import PostTrainingAffineQuantization, QuantizedArray
from . import NumpyModule
def compile_torch_model(
torch_model: torch.nn.Module,
torch_inputset: torch.FloatTensor,
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
show_mlir: bool = False,
n_bits=7,
):
"""Take a model in torch, turn it to numpy, transform weights to integer.
Later, we'll compile the integer model.
Args:
torch_model (torch.nn.Module): the model to quantize,
torch_inputset (torch.FloatTensor): the inputset, in torch form
compilation_configuration (CompilationConfiguration): Configuration object to use
during compilation
compilation_artifacts (CompilationArtifacts): Artifacts object to fill
during compilation
show_mlir (bool): if set, the MLIR produced by the converter and which is going
to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
n_bits: the number of bits for the quantization
"""
# Create corresponding numpy model
numpy_model = NumpyModule(torch_model)
# Torch input to numpy
numpy_inputset = numpy.array(
[
tuple(val.cpu().numpy() for val in input_)
if isinstance(input_, tuple)
else tuple(input_.cpu().numpy())
for input_ in torch_inputset
]
)
# Quantize with post-training static method, to have a model with integer weights
post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_model)
quantized_model = post_training_quant.quantize_module(numpy_inputset)
model_to_compile = quantized_model
# Quantize input
quantized_numpy_inputset = QuantizedArray(n_bits, numpy_inputset)
# FIXME: just print, to avoid to have useless vars. Will be removed once we can finally compile
# the model
print(
model_to_compile,
quantized_numpy_inputset,
compilation_configuration,
compilation_artifacts,
show_mlir,
)

View File

@@ -0,0 +1,62 @@
"""Tests for the torch to numpy module."""
import pytest
import torch
from torch import nn
from concrete.torch.compile import compile_torch_model
class FC(nn.Module):
"""Torch model for the tests"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
self.sigmoid1 = nn.Sigmoid()
self.fc2 = nn.Linear(in_features=128, out_features=64)
self.sigmoid2 = nn.Sigmoid()
self.fc3 = nn.Linear(in_features=64, out_features=64)
self.sigmoid3 = nn.Sigmoid()
self.fc4 = nn.Linear(in_features=64, out_features=64)
self.sigmoid4 = nn.Sigmoid()
self.fc5 = nn.Linear(in_features=64, out_features=10)
def forward(self, x):
"""Forward pass."""
out = self.fc1(x)
out = self.sigmoid1(out)
out = self.fc2(out)
out = self.sigmoid2(out)
out = self.fc3(out)
out = self.sigmoid3(out)
out = self.fc4(out)
out = self.sigmoid4(out)
out = self.fc5(out)
return out
@pytest.mark.parametrize(
"model, input_shape",
[
pytest.param(FC, (100, 32 * 32 * 3)),
],
)
def test_compile_torch(model, input_shape, default_compilation_configuration, seed_torch):
"""Test the different model architecture from torch numpy."""
# Seed torch
seed_torch()
# Define the torch model
torch_fc_model = model()
# Create random input
torch_inputset = torch.randn(input_shape)
# Compile
compile_torch_model(
torch_fc_model,
torch_inputset,
default_compilation_configuration,
)