commit edefde9189 (parent e111891a41) in zama-ai/concrete
mirror of https://github.com/zama-ai/concrete.git
committed by Benoit Chevallier
concrete/torch/compile.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""torch compilation function."""

from typing import Optional

import numpy
import torch

from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..quantization import PostTrainingAffineQuantization, QuantizedArray
from . import NumpyModule


def compile_torch_model(
    torch_model: torch.nn.Module,
    torch_inputset: torch.FloatTensor,
    compilation_configuration: Optional[CompilationConfiguration] = None,
    compilation_artifacts: Optional[CompilationArtifacts] = None,
    show_mlir: bool = False,
    n_bits=7,
):
    """Take a torch model, convert it to numpy, and quantize its weights to integers.

    Later, we'll compile the integer model.

    Args:
        torch_model (torch.nn.Module): the model to quantize
        torch_inputset (torch.FloatTensor): the inputset, in torch form
        compilation_configuration (CompilationConfiguration): Configuration object to use
            during compilation
        compilation_artifacts (CompilationArtifacts): Artifacts object to fill
            during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
        n_bits: the number of bits for the quantization

    """

    # Create the corresponding numpy model
    numpy_model = NumpyModule(torch_model)

    # Convert the torch inputset to numpy
    numpy_inputset = numpy.array(
        [
            tuple(val.cpu().numpy() for val in input_)
            if isinstance(input_, tuple)
            else tuple(input_.cpu().numpy())
            for input_ in torch_inputset
        ]
    )

    # Quantize with the post-training static method, to get a model with integer weights
    post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_model)
    quantized_model = post_training_quant.quantize_module(numpy_inputset)
    model_to_compile = quantized_model

    # Quantize the input
    quantized_numpy_inputset = QuantizedArray(n_bits, numpy_inputset)

    # FIXME: just print, to avoid unused variables. Will be removed once we can finally compile
    # the model
    print(
        model_to_compile,
        quantized_numpy_inputset,
        compilation_configuration,
        compilation_artifacts,
        show_mlir,
    )
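For context, a minimal sketch of how compile_torch_model could be invoked, mirroring the test added below. The TinyModel class and the input shape are illustrative placeholders, not part of this commit, and the call assumes the concrete package from this repository is importable:

import torch
from torch import nn

from concrete.torch.compile import compile_torch_model


class TinyModel(nn.Module):
    """Illustrative two-layer model (not part of this commit)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=8, out_features=4)
        self.sigmoid = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=4, out_features=2)

    def forward(self, x):
        """Forward pass."""
        return self.fc2(self.sigmoid(self.fc1(x)))


# Random calibration inputset, analogous to the test below
inputset = torch.randn((100, 8))

# At this stage the function quantizes the model and prints the intermediate
# objects; actual compilation is still a FIXME in the source above.
compile_torch_model(TinyModel(), inputset, n_bits=7)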
tests/torch/test_compile_torch.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Tests for the torch to numpy module."""
import pytest
import torch
from torch import nn

from concrete.torch.compile import compile_torch_model


class FC(nn.Module):
    """Torch model for the tests."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.sigmoid2 = nn.Sigmoid()
        self.fc3 = nn.Linear(in_features=64, out_features=64)
        self.sigmoid3 = nn.Sigmoid()
        self.fc4 = nn.Linear(in_features=64, out_features=64)
        self.sigmoid4 = nn.Sigmoid()
        self.fc5 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Forward pass."""
        out = self.fc1(x)
        out = self.sigmoid1(out)
        out = self.fc2(out)
        out = self.sigmoid2(out)
        out = self.fc3(out)
        out = self.sigmoid3(out)
        out = self.fc4(out)
        out = self.sigmoid4(out)
        out = self.fc5(out)

        return out


@pytest.mark.parametrize(
    "model, input_shape",
    [
        pytest.param(FC, (100, 32 * 32 * 3)),
    ],
)
def test_compile_torch(model, input_shape, default_compilation_configuration, seed_torch):
    """Test the different model architectures from torch to numpy."""

    # Seed torch
    seed_torch()

    # Define the torch model
    torch_fc_model = model()

    # Create a random input
    torch_inputset = torch.randn(input_shape)

    # Compile
    compile_torch_model(
        torch_fc_model,
        torch_inputset,
        default_compilation_configuration,
    )
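The test relies on the default_compilation_configuration and seed_torch pytest fixtures, which are defined elsewhere in the repository and are not part of this diff. A hypothetical sketch of what such fixtures could look like, purely for illustration; the repository's actual conftest.py may differ:

# conftest.py (hypothetical sketch, not the repository's actual fixtures)
import pytest
import torch

from concrete.common.compilation import CompilationConfiguration


@pytest.fixture
def default_compilation_configuration():
    """Return a default configuration object for compilation (assumed constructible with defaults)."""
    return CompilationConfiguration()


@pytest.fixture
def seed_torch():
    """Return a callable that seeds torch for reproducible tests."""

    def _seed(seed: int = 0):
        torch.manual_seed(seed)

    return _seed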