feat: conversion from torch.nn.Module to numpy.
@@ -1,3 +1,3 @@
 """Package top import."""
-from . import common, numpy
+from . import common, numpy, torch
 from .version import __version__
concrete/torch/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
"""Modules for torch to numpy conversion."""
|
||||
from .numpy_module import NumpyModule
|
||||
concrete/torch/numpy_module.py (new file, 69 lines)
@@ -0,0 +1,69 @@
"""A torch to numpy module."""
|
||||
import numpy
|
||||
from numpy.typing import ArrayLike
|
||||
from torch import nn
|
||||
|
||||
|
||||
class NumpyModule:
|
||||
"""General interface to transform a torch.nn.Module to numpy module."""
|
||||
|
||||
IMPLEMENTED_MODULES = [nn.Linear, nn.Sigmoid]
|
||||
|
||||
def __init__(self, torch_model: nn.Module):
|
||||
"""Initialize our numpy module.
|
||||
|
||||
Current constraint: All objects used in the forward have to be defined in the
|
||||
__init__() of torch.nn.Module and follow the exact same order.
|
||||
(i.e. each linear layer must have one variable defined in the
|
||||
right order). This constraint will disappear when
|
||||
TorchScript is in place. (issue #818)
|
||||
|
||||
Args:
|
||||
torch_model (nn.Module): A fully trained, torch model alond with its parameters.
|
||||
"""
|
||||
self.torch_model = torch_model
|
||||
self.convert_to_numpy()
|
||||
|
||||
def convert_to_numpy(self):
|
||||
"""Transform all parameters from torch tensor to numpy arrays."""
|
||||
self.numpy_module_dict = {}
|
||||
self.numpy_module_quant_dict = {}
|
||||
|
||||
for name, weights in self.torch_model.state_dict().items():
|
||||
params = weights.detach().numpy()
|
||||
self.numpy_module_dict[name] = params
|
||||
|
||||
def __call__(self, x: ArrayLike):
|
||||
"""Return the function to be compiled by concretefhe.numpy."""
|
||||
return self.forward(x)
|
||||
|
||||
def forward(self, x: ArrayLike) -> ArrayLike:
|
||||
"""Apply a forward pass with numpy function only.
|
||||
|
||||
Args:
|
||||
x (numpy.array): Input to be processed in the forward pass.
|
||||
|
||||
Returns:
|
||||
x (numpy.array): Processed input.
|
||||
"""
|
||||
|
||||
for name, layer in self.torch_model.named_children():
|
||||
|
||||
if isinstance(layer, nn.Linear):
|
||||
# Apply a matmul product and add the bias.
|
||||
x = (
|
||||
x @ self.numpy_module_dict[f"{name}.weight"].T
|
||||
+ self.numpy_module_dict[f"{name}.bias"]
|
||||
)
|
||||
elif isinstance(layer, nn.Sigmoid):
|
||||
# concrete currently does not accept the "-" python operator
|
||||
# hence the use of numpy.negative which is supported.
|
||||
x = 1 / (1 + numpy.exp(numpy.negative(x)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The follwing module is currently not implemented: {type(layer).__name__}"
|
||||
f"Please stick to the available torch modules:"
|
||||
f"{', '.join([module.__name__ for module in self.IMPLEMENTED_MODULES])}."
|
||||
)
|
||||
|
||||
return x
|
||||
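For illustration, here is a minimal usage sketch (not part of this commit) of how NumpyModule is intended to be used, assuming a toy model built only from the implemented modules (nn.Linear, nn.Sigmoid) and respecting the ordering constraint described in __init__:

# Hypothetical usage sketch: convert a small torch model to numpy
# and run a forward pass on plain numpy arrays.
import numpy
from torch import nn

from concrete.torch import NumpyModule


class TinyFC(nn.Module):
    """Two-layer model using only the implemented modules, defined in forward order."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.sigmoid1(self.fc1(x)))


torch_model = TinyFC()
numpy_model = NumpyModule(torch_model)

x = numpy.random.randn(10, 4).astype(numpy.float32)
y = numpy_model(x)  # forward pass executed with numpy ops only
assert y.shape == (10, 2)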
tests/torch/test_torch_to_numpy.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Tests for the torch to numpy module."""
|
||||
import numpy
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from concrete.torch import NumpyModule
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
"""Torch CNN model for the tests."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 5)
|
||||
self.pool = nn.AvgPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
||||
self.fc2 = nn.Linear(120, 84)
|
||||
self.fc3 = nn.Linear(84, 10)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass."""
|
||||
x = self.pool(torch.relu(self.conv1(x)))
|
||||
x = self.pool(torch.relu(self.conv2(x)))
|
||||
x = torch.flatten(x, 1)
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class FC(nn.Module):
|
||||
"""Torch model for the tests"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_features=32 * 32 * 3, out_features=128)
|
||||
self.sigmoid1 = nn.Sigmoid()
|
||||
self.fc2 = nn.Linear(in_features=128, out_features=64)
|
||||
self.sigmoid2 = nn.Sigmoid()
|
||||
self.fc3 = nn.Linear(in_features=64, out_features=64)
|
||||
self.sigmoid3 = nn.Sigmoid()
|
||||
self.fc4 = nn.Linear(in_features=64, out_features=64)
|
||||
self.sigmoid4 = nn.Sigmoid()
|
||||
self.fc5 = nn.Linear(in_features=64, out_features=10)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass."""
|
||||
out = self.fc1(x)
|
||||
out = self.sigmoid1(out)
|
||||
out = self.fc2(out)
|
||||
out = self.sigmoid2(out)
|
||||
out = self.fc3(out)
|
||||
out = self.sigmoid3(out)
|
||||
out = self.fc4(out)
|
||||
out = self.sigmoid4(out)
|
||||
out = self.fc5(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, input_shape",
|
||||
[
|
||||
pytest.param(FC, (100, 32 * 32 * 3)),
|
||||
pytest.param(CNN, (100, 3, 32, 32), marks=pytest.mark.xfail(strict=True)),
|
||||
],
|
||||
)
|
||||
def test_torch_to_numpy(model, input_shape):
|
||||
"""Test the different model architecture from torch numpy."""
|
||||
|
||||
torch_fc_model = model()
|
||||
torch_input = torch.randn(input_shape)
|
||||
torch_predictions = torch_fc_model(torch_input).detach().numpy()
|
||||
numpy_fc_model = NumpyModule(torch_fc_model)
|
||||
# torch_input to numpy.
|
||||
numpy_input = torch_input.detach().numpy()
|
||||
numpy_predictions = numpy_fc_model(numpy_input)
|
||||
|
||||
assert numpy_predictions.shape == torch_predictions.shape
|
||||
assert numpy.isclose(torch_predictions, numpy_predictions, rtol=10 - 3).all()
|
||||
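As a quick sanity check (again not part of this commit), the numpy.negative-based sigmoid used in NumpyModule.forward can be compared directly against torch.sigmoid; the rtol value below simply mirrors the tolerance used in the test above:

# Standalone, illustrative check: the numpy expression from
# NumpyModule.forward matches torch.sigmoid within the test tolerance.
import numpy
import torch

x = numpy.linspace(-5.0, 5.0, 11)
numpy_sigmoid = 1 / (1 + numpy.exp(numpy.negative(x)))
torch_sigmoid = torch.sigmoid(torch.from_numpy(x)).numpy()

assert numpy.isclose(numpy_sigmoid, torch_sigmoid, rtol=10**-3).all()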