concrete/tests/quantization/test_compilation.py

"""Test Neural Networks compilations"""
import numpy
import pytest
from torch import nn
from concrete.quantization import PostTrainingAffineQuantization, QuantizedArray
from concrete.torch import NumpyModule
# INPUT_OUTPUT_FEATURE is the number of input and output of each of the network layers.
# (as well as the input of the network itself)
# Currently, with 7 bits maximum, we can use 15 weights max in the theoretical case.
INPUT_OUTPUT_FEATURE = [1, 2, 3]


class FC(nn.Module):
    """Torch model for the tests"""

    def __init__(self, input_output):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_output, out_features=input_output)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=input_output, out_features=input_output)

    def forward(self, x):
        """Forward pass."""
        out = self.fc1(x)
        out = self.sigmoid1(out)
        out = self.fc2(out)
        return out
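

# Illustrative aside (not part of the original test file): FC(d) is a square
# two-layer MLP, Linear -> Sigmoid -> Linear, so its output shape matches its
# input shape. A quick standalone sanity check, assuming only torch:
def _example_fc_shape_check():
    import torch  # local import: this helper is illustration only

    model = FC(input_output=3)
    out = model(torch.randn(50, 3))
    assert out.shape == (50, 3)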


@pytest.mark.parametrize(
    "model",
    [pytest.param(FC)],
)
@pytest.mark.parametrize(
    "input_output_feature",
    [pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
)
def test_quantized_module_compilation(
    input_output_feature,
    model,
    seed_torch,
    default_compilation_configuration,
    check_is_good_execution,
):
"""Test a neural network compilation for FHE inference."""
# Seed torch
seed_torch()
n_bits = 2
# Define an input shape (n_examples, n_features)
input_shape = (50, input_output_feature)
# Build a random Quantized Fully Connected Neural Network
# Define the torch model
torch_fc_model = model(input_output_feature)
# Create random input
numpy_input = numpy.random.uniform(-100, 100, size=input_shape)
# Create corresponding numpy model
numpy_fc_model = NumpyModule(torch_fc_model)
# Quantize with post-training static method
post_training_quant = PostTrainingAffineQuantization(n_bits, numpy_fc_model)
quantized_model = post_training_quant.quantize_module(numpy_input)
# Quantize input
q_input = QuantizedArray(n_bits, numpy_input)
quantized_model(q_input)
# Compile
quantized_model.compile(q_input, default_compilation_configuration)

    # Check the FHE execution against the clear inference, one example at a time
    for x_q in q_input.qvalues:
        x_q = numpy.expand_dims(x_q, 0)
        check_is_good_execution(
            fhe_circuit=quantized_model.forward_fhe,
            function=quantized_model.forward,
            args=[x_q.astype(numpy.uint8)],
            postprocess_output_func=lambda x: quantized_model.dequantize_output(
                x.astype(numpy.float32)
            ),
            check_function=lambda lhs, rhs: numpy.isclose(lhs, rhs).all(),
            verbose=False,
        )
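
The fixtures seed_torch, default_compilation_configuration and check_is_good_execution are not defined in this file; they presumably come from the test suite's conftest.py. A minimal sketch of what a seed_torch fixture could look like (an assumption for illustration, not the project's actual implementation):

import random

import numpy
import pytest
import torch


@pytest.fixture
def seed_torch():
    """Return a helper that seeds the Python, numpy and torch RNGs."""

    def _seed(seed=0):
        # Seed every RNG the test can touch so that runs are reproducible.
        random.seed(seed)
        numpy.random.seed(seed)
        torch.manual_seed(seed)

    return _seed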