mirror of https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add quantization utilities
This commit is contained in:
@@ -1,3 +1,3 @@
 """Package top import."""
-from . import common, numpy, torch
+from . import common, numpy, quantization, torch
 from .version import __version__
concrete/quantization/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""Modules for quantization."""
from .quantized_activations import QuantizedSigmoid
from .quantized_array import QuantizedArray
from .quantized_layers import QuantizedLinear
concrete/quantization/quantized_activations.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""Quantized activation functions."""
import copy
from abc import ABC, abstractmethod
from typing import Optional

import numpy

from .quantized_array import QuantizedArray


class QuantizedActivation(ABC):
    """Base class for quantized activation functions."""

    q_out: Optional[QuantizedArray]

    def __init__(self, n_bits: int) -> None:
        self.n_bits = n_bits
        self.q_out = None

    @abstractmethod
    def __call__(self, q_input: QuantizedArray) -> QuantizedArray:
        """Execute the forward pass."""

    @abstractmethod
    def calibrate(self, x: numpy.ndarray) -> None:
        """Create the corresponding QuantizedArray for the output of the activation function.

        Args:
            x (numpy.ndarray): Inputs.
        """

    @staticmethod
    def dequant_input(q_input: QuantizedArray) -> numpy.ndarray:
        """Dequantize the input of the activation function.

        Args:
            q_input (QuantizedArray): Quantized array for the inputs.

        Returns:
            numpy.ndarray: The dequantized input as a numpy array.
        """
        return (q_input.qvalues - q_input.zero_point) * q_input.scale

    def quant_output(self, qoutput_activation: numpy.ndarray) -> QuantizedArray:
        """Quantize the output of the activation function.

        Args:
            qoutput_activation (numpy.ndarray): Output of the activation function.

        Returns:
            QuantizedArray: Quantized output.
        """
        assert self.q_out is not None

        qoutput_activation = qoutput_activation / self.q_out.scale + self.q_out.zero_point
        qoutput_activation = (
            qoutput_activation.round().clip(0, 2 ** self.q_out.n_bits - 1).astype(int)
        )

        # TODO find a better way to do the following (see issue #832)
        q_out = copy.copy(self.q_out)
        q_out.update_qvalues(qoutput_activation)
        return q_out


class QuantizedSigmoid(QuantizedActivation):
    """Quantized sigmoid activation function."""

    def calibrate(self, x: numpy.ndarray):
        self.q_out = QuantizedArray(self.n_bits, 1 / (1 + numpy.exp(-x)))

    def __call__(self, q_input: QuantizedArray) -> QuantizedArray:
        """Process the forward pass of the quantized sigmoid.

        Args:
            q_input (QuantizedArray): Quantized input.

        Returns:
            QuantizedArray: Quantized output.
        """
        # Dequantize the input, apply the sigmoid in floating point,
        # then requantize with the calibrated output parameters.
        quant_sigmoid = self.dequant_input(q_input)
        quant_sigmoid = 1 + numpy.exp(-quant_sigmoid)
        quant_sigmoid = 1 / quant_sigmoid

        q_out = self.quant_output(quant_sigmoid)
        return q_out
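A minimal usage sketch of the calibrate-then-call flow defined above (toy data; n_bits=8 is an arbitrary choice):

    import numpy
    from concrete.quantization import QuantizedArray, QuantizedSigmoid

    x = numpy.random.uniform(size=(10, 4))  # calibration data
    act = QuantizedSigmoid(n_bits=8)
    act.calibrate(x)                 # fixes the output quantization parameters
    q_y = act(QuantizedArray(8, x))  # integer qvalues in [0, 2**8 - 1]
    y = q_y.dequant()                # approximately 1 / (1 + numpy.exp(-x))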
concrete/quantization/quantized_array.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Quantization utilities for a numpy array/tensor."""
from copy import deepcopy
from typing import Optional

import numpy

STABILITY_CONST = 10 ** -12


class QuantizedArray:
    """Abstraction of a quantized array."""

    def __init__(self, n_bits: int, values: numpy.ndarray):
        """Quantize an array.

        See https://arxiv.org/abs/1712.05877.

        Args:
            n_bits (int): The number of bits to use for quantization.
            values (numpy.ndarray): Values to be quantized.
        """
        self.values = values
        self.n_bits = n_bits
        self.scale, self.zero_point, self.qvalues = self.compute_quantization_parameters()

    def __call__(self) -> Optional[numpy.ndarray]:
        return self.qvalues

    def compute_quantization_parameters(self):
        """Compute the quantization parameters."""
        # Small constant needed for stability
        rmax = numpy.max(self.values) + STABILITY_CONST
        rmin = numpy.min(self.values)
        scale = (rmax - rmin) / (2 ** self.n_bits - 1) if rmax != rmin else 1.0

        zero_point = numpy.round(-(rmin / scale)).astype(int)

        # Compute quantized values and store
        qvalues = self.values / scale + zero_point
        qvalues = (
            qvalues.round()
            .clip(0, 2 ** self.n_bits - 1)
            .astype(int)  # Careful: this can be very large with a high number of bits
        )
        return scale, zero_point, qvalues

    def update_values(self, values: numpy.ndarray) -> Optional[numpy.ndarray]:
        """Update values to get their corresponding qvalues using the related quantization parameters.

        Args:
            values (numpy.ndarray): Values to replace self.values.

        Returns:
            qvalues (numpy.ndarray): Corresponding qvalues.
        """
        self.values = deepcopy(values)
        self.quant()
        return self.qvalues

    def update_qvalues(self, qvalues: numpy.ndarray) -> Optional[numpy.ndarray]:
        """Update qvalues to get their corresponding values using the related quantization parameters.

        Args:
            qvalues (numpy.ndarray): Values to replace self.qvalues.

        Returns:
            values (numpy.ndarray): Corresponding values.
        """
        self.qvalues = deepcopy(qvalues)
        self.dequant()
        return self.values

    def quant(self) -> Optional[numpy.ndarray]:
        """Quantize self.values.

        Returns:
            numpy.ndarray: Quantized values.
        """
        self.qvalues = (
            (self.values / self.scale + self.zero_point)
            .round()
            .clip(0, 2 ** self.n_bits - 1)
            .astype(int)
        )
        return self.qvalues

    def dequant(self) -> numpy.ndarray:
        """Dequantize self.qvalues.

        Returns:
            numpy.ndarray: Dequantized values.
        """
        self.values = self.scale * (self.qvalues - self.zero_point)
        return self.values
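A quick numeric check of the affine scheme above (toy values; n_bits=3 keeps the numbers readable):

    import numpy
    from concrete.quantization import QuantizedArray

    q = QuantizedArray(3, numpy.array([0.0, 0.4, 1.0]))
    # rmin = 0.0 and rmax = 1.0 (plus STABILITY_CONST), so scale ~= 1 / 7 and zero_point = 0
    print(q.qvalues)    # [0, 3, 7]: 0.4 / scale ~= 2.8, which rounds to 3
    print(q.dequant())  # [0.0, ~0.43, 1.0]; the quantization error shrinks as n_bits grows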
concrete/quantization/quantized_layers.py (new file, 80 lines)
@@ -0,0 +1,80 @@
"""Quantized layers."""
import copy
from typing import Optional

import numpy

from .quantized_array import QuantizedArray


class QuantizedLinear:
    """Fully connected quantized layer."""

    q_out: Optional[QuantizedArray]

    def __init__(
        self, n_bits: int, q_weights: QuantizedArray, q_bias: Optional[QuantizedArray] = None
    ):
        """Implement the forward pass of a quantized linear layer.

        Note: QuantizedLinear seems to become unstable when n_bits > 23.

        Args:
            n_bits (int): Maximum number of bits for the output.
            q_weights (QuantizedArray): Quantized weights (n_neurons, n_features).
            q_bias (QuantizedArray, optional): Quantized bias (n_neurons). Defaults to None.
        """
        self.q_weights = q_weights
        self.q_bias = q_bias
        self.n_bits = n_bits

        if self.q_bias is None:
            self.q_bias = QuantizedArray(n_bits, numpy.zeros(self.q_weights.values.shape[:-1]))
        self.q_out = None

    def calibrate(self, x: numpy.ndarray):
        """Create the corresponding QuantizedArray for the output of QuantizedLinear.

        Args:
            x (numpy.ndarray): Inputs.
        """
        assert self.q_bias is not None
        self.q_out = QuantizedArray(self.n_bits, x @ self.q_weights.values.T + self.q_bias.values)

    def __call__(self, q_input: QuantizedArray) -> QuantizedArray:
        """Process the forward pass of the quantized linear layer.

        Note: in standard quantization, floats are problematic as quantization
        targets integer-only hardware. In FHE, however, we can create a table lookup
        to bypass this problem, so we leave the floats as is.

        Args:
            q_input (QuantizedArray): Quantized input.

        Returns:
            q_out_ (QuantizedArray): Quantized output.
        """
        # Satisfy mypy.
        assert self.q_out is not None
        assert self.q_bias is not None

        # We need to expand the following equation to get the main computation
        # (self.q_weights.qvalues @ q_input.qvalues) without zero_point values.
        # See https://github.com/google/gemmlowp/blob/master/doc/quantization.md #852

        m_product = (q_input.scale * self.q_weights.scale) / (self.q_out.scale)
        dot_product = (q_input.qvalues - q_input.zero_point) @ (
            self.q_weights.qvalues - self.q_weights.zero_point
        ).T

        m_bias = self.q_bias.scale / (q_input.scale * self.q_weights.scale)
        bias_part = m_bias * (self.q_bias.qvalues - self.q_bias.zero_point)
        numpy_q_out = m_product * (dot_product + bias_part) + self.q_out.zero_point

        numpy_q_out = numpy_q_out.round().clip(0, 2 ** self.q_out.n_bits - 1).astype(int)

        # TODO find a more intuitive way to do the following (see issue #832)
        # We should be able to reuse q_out quantization parameters
        # easily to get a new QuantizedArray
        q_out_ = copy.copy(self.q_out)
        q_out_.update_qvalues(numpy_q_out)

        return q_out_
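For reference, the expansion used in __call__ above follows from substituting the affine coding into the float computation; a sketch of the algebra in the code's notation (s_* are scales, z_* zero points), matching the gemmlowp document linked in the comment:

    # Float layer:    y = x @ W.T + b
    # Affine coding:  x = s_x * (q_x - z_x), W = s_w * (q_w - z_w), b = s_b * (q_b - z_b)
    # Output coding:  q_y = y / s_y + z_y
    #
    # Substituting and factoring out s_x * s_w:
    #   q_y = (s_x * s_w / s_y) * ((q_x - z_x) @ (q_w - z_w).T
    #         + (s_b / (s_x * s_w)) * (q_b - z_b)) + z_y
    #
    # i.e. m_product * (dot_product + bias_part) + self.q_out.zero_point.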
@@ -1,13 +1,12 @@
 """A torch to numpy module."""
 import numpy
-from numpy.typing import ArrayLike
 from torch import nn


 class NumpyModule:
     """General interface to transform a torch.nn.Module to a numpy module."""

-    IMPLEMENTED_MODULES = [nn.Linear, nn.Sigmoid]
+    IMPLEMENTED_MODULES = {nn.Linear, nn.Sigmoid}

     def __init__(self, torch_model: nn.Module):
         """Initialize our numpy module.

@@ -22,8 +21,21 @@ class NumpyModule:
             torch_model (nn.Module): A fully trained torch model along with its parameters.
         """
         self.torch_model = torch_model
+        self.check_compatibility()
         self.convert_to_numpy()

+    def check_compatibility(self):
+        """Check the compatibility of all layers in the torch model."""
+
+        for _, layer in self.torch_model.named_children():
+            if (layer_type := type(layer)) not in self.IMPLEMENTED_MODULES:
+                raise ValueError(
+                    f"The following module is currently not implemented: {layer_type.__name__}. "
+                    f"Please stick to the available torch modules: "
+                    f"{', '.join(sorted(module.__name__ for module in self.IMPLEMENTED_MODULES))}."
+                )
+        return True
+
     def convert_to_numpy(self):
         """Transform all parameters from torch tensors to numpy arrays."""
         self.numpy_module_dict = {}

@@ -33,11 +45,11 @@ class NumpyModule:
             params = weights.detach().numpy()
             self.numpy_module_dict[name] = params

-    def __call__(self, x: ArrayLike):
+    def __call__(self, x: numpy.ndarray):
         """Return the function to be compiled by concretefhe.numpy."""
         return self.forward(x)

-    def forward(self, x: ArrayLike) -> ArrayLike:
+    def forward(self, x: numpy.ndarray) -> numpy.ndarray:
         """Apply a forward pass with numpy functions only.

         Args:

@@ -56,14 +68,6 @@ class NumpyModule:
                     + self.numpy_module_dict[f"{name}.bias"]
                 )
             elif isinstance(layer, nn.Sigmoid):
-                x = 1 / (1 + numpy.exp(-x))
-            else:
-                raise ValueError(
-                    f"The follwing module is currently not implemented: {type(layer).__name__}"
-                    f"Please stick to the available torch modules:"
-                    f"{', '.join([module.__name__ for module in self.IMPLEMENTED_MODULES])}."
-                )
+                # concrete currently does not accept the "-" python operator,
+                # hence the use of numpy.negative, which is supported.
+                x = 1 / (1 + numpy.exp(numpy.negative(x)))

         return x
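A minimal sketch of how the converted module is meant to be used (hypothetical two-layer model; NumpyModule's import path is not shown in this diff):

    import numpy
    import torch
    from torch import nn

    # Only nn.Linear and nn.Sigmoid are supported, per IMPLEMENTED_MODULES.
    torch_model = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid())
    numpy_model = NumpyModule(torch_model)  # raises ValueError on unsupported layers

    x = torch.randn(10, 4)
    torch_out = torch_model(x).detach().numpy()
    numpy_out = numpy_model(x.detach().numpy())
    assert numpy.isclose(torch_out, numpy_out, rtol=10 ** -3).all()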
tests/quantization/test_quantized_activations.py (new file, 42 lines)
@@ -0,0 +1,42 @@
"""Tests for the quantized activation functions."""
import numpy
import pytest

from concrete.quantization import QuantizedArray, QuantizedSigmoid

N_BITS_ATOL_TUPLE_LIST = [
    (32, 10 ** -2),
    (28, 10 ** -2),
    (20, 10 ** -2),
    (16, 10 ** -1),
    (8, 10 ** -0),
    (4, 10 ** -0),
]


@pytest.mark.parametrize(
    "n_bits, atol",
    [pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize(
    "quant_activation, values",
    [pytest.param(QuantizedSigmoid, numpy.random.uniform(size=(10, 40, 20)))],
)
def test_activations(quant_activation, values, n_bits, atol):
    """Test activation functions."""
    q_inputs = QuantizedArray(n_bits, values)
    quant_sigmoid = quant_activation(n_bits)
    quant_sigmoid.calibrate(values)
    expected_output = quant_sigmoid.q_out.values
    q_output = quant_sigmoid(q_inputs)
    qvalues = q_output.qvalues

    # Quantized values must be contained between 0 and 2**n_bits - 1.
    assert numpy.max(qvalues) <= 2 ** n_bits - 1
    assert numpy.min(qvalues) >= 0

    # Dequantized values must be close to the original values
    dequant_values = q_output.dequant()

    # Check that all values are close
    assert numpy.isclose(dequant_values, expected_output, atol=atol).all()
tests/quantization/test_quantized_array.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""Tests for the quantized array/tensors."""
import numpy
import pytest

from concrete.quantization import QuantizedArray

N_BITS_ATOL_TUPLE_LIST = [
    (32, 10 ** -2),
    (28, 10 ** -2),
    (20, 10 ** -2),
    (16, 10 ** -1),
    (8, 10 ** -0),
    (4, 10 ** -0),
]


@pytest.mark.parametrize(
    "n_bits, atol",
    [pytest.param(n_bits, atol) for n_bits, atol in N_BITS_ATOL_TUPLE_LIST],
)
@pytest.mark.parametrize("values", [pytest.param(numpy.random.randn(2000))])
def test_quant_dequant_update(values, n_bits, atol):
    """Test the quant and dequant functions."""

    quant_array = QuantizedArray(n_bits, values)
    qvalues = quant_array.quant()

    # Quantized values must be contained between 0 and 2**n_bits - 1
    assert numpy.max(qvalues) <= 2 ** n_bits - 1
    assert numpy.min(qvalues) >= 0

    # Dequantized values must be close to the original values
    dequant_values = quant_array.dequant()

    # Check that all values are close
    assert numpy.isclose(dequant_values, values, atol=atol).all()

    # Test the update functions
    new_values = numpy.array([0.3, 0.5, -1.2, -3.4])
    new_qvalues_ = quant_array.update_values(new_values)

    # Make sure the shape changed for the qvalues
    assert new_qvalues_.shape != qvalues.shape

    new_qvalues = numpy.array([1, 4, 7, 29])
    new_values_updated = quant_array.update_qvalues(new_qvalues)

    # Make sure that we can see at least one change.
    assert not numpy.array_equal(new_qvalues, new_qvalues_)
    assert not numpy.array_equal(new_values, new_values_updated)

    # Check that __call__ also returns the qvalues.
    assert numpy.array_equal(quant_array(), new_qvalues)
tests/quantization/test_quantized_layers.py (new file, 58 lines)
@@ -0,0 +1,58 @@
"""Tests for the quantized layers."""
import numpy
import pytest

from concrete.quantization import QuantizedArray, QuantizedLinear

# QuantizedLinear is unstable with n_bits > 23.
N_BITS_LIST = [20, 16, 8, 4]


@pytest.mark.parametrize(
    "n_bits",
    [pytest.param(n_bits) for n_bits in N_BITS_LIST],
)
@pytest.mark.parametrize(
    "n_examples, n_features, n_neurons",
    [
        pytest.param(20, 500, 30),
        pytest.param(200, 300, 50),
        pytest.param(10000, 100, 1),
        pytest.param(10, 20, 1),
    ],
)
def test_quantized_linear(n_examples, n_features, n_neurons, n_bits):
    """Test the quantized linear layer on numpy arrays.

    With n_bits >> 0 we expect the results of the quantized linear layer
    to be the same as those of the standard linear layer.
    """
    inputs = numpy.random.uniform(size=(n_examples, n_features))
    q_inputs = QuantizedArray(n_bits, inputs)

    # Shape of the weights: (n_neurons, n_features)
    weights = numpy.random.uniform(size=(n_neurons, n_features))
    q_weights = QuantizedArray(n_bits, weights)

    bias = numpy.random.uniform(size=(n_neurons,))
    q_bias = QuantizedArray(n_bits, bias)

    # Define our QuantizedLinear layer
    q_linear = QuantizedLinear(n_bits, q_weights, q_bias)

    # Calibrate the quantized layer
    q_linear.calibrate(inputs)
    expected_outputs = q_linear.q_out.values
    actual_output = q_linear(q_inputs).dequant()

    assert numpy.isclose(expected_outputs, actual_output, rtol=10 ** -1).all()

    # Same test without bias
    q_linear = QuantizedLinear(n_bits, q_weights)

    # Calibrate the quantized layer
    q_linear.calibrate(inputs)
    expected_outputs = q_linear.q_out.values
    actual_output = q_linear(q_inputs).dequant()

    assert numpy.isclose(expected_outputs, actual_output, rtol=10 ** -1).all()
@@ -64,19 +64,54 @@ class FC(nn.Module):
     "model, input_shape",
     [
         pytest.param(FC, (100, 32 * 32 * 3)),
         pytest.param(CNN, (100, 3, 32, 32), marks=pytest.mark.xfail(strict=True)),
     ],
 )
 def test_torch_to_numpy(model, input_shape):
     """Test the different model architectures from torch to numpy."""

     # Define the torch model
     torch_fc_model = model()
-    torch_input = torch.randn(input_shape)
-    torch_predictions = torch_fc_model(torch_input).detach().numpy()
+    # Create random input
+    torch_input_1 = torch.randn(input_shape)
+    # Predict with torch model
+    torch_predictions = torch_fc_model(torch_input_1).detach().numpy()
     # Create corresponding numpy model
     numpy_fc_model = NumpyModule(torch_fc_model)
-    # torch_input to numpy.
-    numpy_input = torch_input.detach().numpy()
-    numpy_predictions = numpy_fc_model(numpy_input)
+    # Torch input to numpy
+    numpy_input_1 = torch_input_1.detach().numpy()
+    # Predict with numpy model
+    numpy_predictions = numpy_fc_model(numpy_input_1)

     # Test: the output of the numpy model is the same as the torch model.
     assert numpy_predictions.shape == torch_predictions.shape
+    # Test: predictions from the numpy model are the same as the torch model's.
+    assert numpy.isclose(torch_predictions, numpy_predictions, rtol=10 ** -3).all()
+
+    # Test: dynamics between layers is working (quantized input and activations)
+    torch_input_2 = torch.randn(input_shape)
+    # Make sure both inputs are different
+    assert (torch_input_1 != torch_input_2).any()
+    # Predict with torch
+    torch_predictions = torch_fc_model(torch_input_2).detach().numpy()
+    # Torch input to numpy
+    numpy_input_2 = torch_input_2.detach().numpy()
+    # Numpy predictions using the previous model
+    numpy_predictions = numpy_fc_model(numpy_input_2)
+    assert numpy.isclose(torch_predictions, numpy_predictions, rtol=10 ** -3).all()
+
+
+@pytest.mark.parametrize(
+    "model, incompatible_layer",
+    [pytest.param(CNN, "Conv2d")],
+)
+def test_raises(model, incompatible_layer):
+    """Test incompatible layers."""
+
+    torch_incompatible_model = model()
+    expected_errmsg = (
+        f"The following module is currently not implemented: {incompatible_layer}. "
+        f"Please stick to the available torch modules: "
+        f"{', '.join(sorted(module.__name__ for module in NumpyModule.IMPLEMENTED_MODULES))}."
+    )
+    with pytest.raises(ValueError, match=expected_errmsg):
+        NumpyModule(torch_incompatible_model)