Mirror of https://github.com/zama-ai/concrete.git, synced 2026-02-08 19:44:57 -05:00
chore: make check_is_good_execution a fixture and fix flaky tests using it
closes #1061
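
Not part of the diff below: a minimal sketch of how a test can use the new check_is_good_execution fixture, assuming it is defined in a conftest.py visible to the test module. FlakyCircuit and the test name are hypothetical stand-ins; only the fixture's call signature comes from this commit.

import numpy


class FlakyCircuit:
    """Stand-in for an FHECircuit whose run() returns a wrong result on its first two calls."""

    def __init__(self):
        self.calls = 0

    def run(self, x):
        self.calls += 1
        return x + 42 if self.calls >= 3 else x


def test_flaky_circuit_eventually_matches(check_is_good_execution):
    # The fixture retries the comparison up to 5 times, so the two wrong results
    # returned by FlakyCircuit do not make the test fail.
    check_is_good_execution(FlakyCircuit(), lambda x: x + 42, [numpy.uint8(3)])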
@@ -4,14 +4,16 @@ import operator
import random
import re
from pathlib import Path
from typing import Callable, Dict, Type
from typing import Any, Callable, Dict, Iterable, Type

import networkx as nx
import networkx.algorithms.isomorphism as iso
import numpy
import pytest
import torch

from concrete.common.compilation import CompilationConfiguration
from concrete.common.fhe_circuit import FHECircuit
from concrete.common.representation.intermediate import (
ALL_IR_NODES,
Add,
@@ -293,3 +295,43 @@ def seed_torch():
"""Fixture to seed torch"""

return function_to_seed_torch


def check_is_good_execution_impl(
fhe_circuit: FHECircuit,
function: Callable,
args: Iterable[Any],
preprocess_input_func: Callable[[Any], Any] = lambda x: x,
postprocess_output_func: Callable[[Any], Any] = lambda x: x,
check_function: Callable[[Any, Any], bool] = numpy.array_equal,
verbose: bool = True,
):
"""Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
return an error. One can set the expected probability of success of one execution and the
number of tests, to finetune the probability of bad luck, ie that we run several times the
check and always have a wrong result."""
nb_tries = 5

for i in range(1, nb_tries + 1):
preprocessed_args = tuple(preprocess_input_func(val) for val in args)
if check_function(
last_engine_result := postprocess_output_func(fhe_circuit.run(*preprocessed_args)),
last_function_result := postprocess_output_func(function(*preprocessed_args)),
):
# Good computation after i tries
if verbose:
print(f"Good computation after {i} tries")
return

# Bad computation after nb_tries
raise AssertionError(
f"bad computation after {nb_tries} tries.\nLast engine result:\n{last_engine_result}\n"
f"Last function result:\n{last_function_result}"
)


@pytest.fixture
def check_is_good_execution():
"""Fixture to check that an FHE execution gives the expected result"""

return check_is_good_execution_impl

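The retry loop above is what de-flakes the tests referenced in the commit title: a single unlucky FHE execution no longer fails a test. With the constants used by the removed helper further down (a per-execution success probability of 0.95 and nb_tries = 5), the chance that a correct circuit is wrong on every retry is tiny; a quick sketch of that figure:

expected_probability_of_success = 0.95  # constant from the removed check_is_good_execution below
nb_tries = 5  # same retry count the new implementation hard-codes
expected_bad_luck = (1 - expected_probability_of_success) ** nb_tries
print(expected_bad_luck)  # ~3.1e-07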
@@ -305,39 +305,13 @@ def negative_unary_f(func, x, y):
return z


def check_is_good_execution(compiler_engine, function, args, verbose=True):
"""Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
return an error. One can set the expected probability of success of one execution and the
number of tests, to finetune the probability of bad luck, ie that we run several times the
check and always have a wrong result."""
expected_probability_of_success = 0.95
nb_tries = 5
expected_bad_luck = (1 - expected_probability_of_success) ** nb_tries

for i in range(1, nb_tries + 1):
if numpy.array_equal(
last_engine_result := compiler_engine.run(*args),
last_function_result := function(*args),
):
# Good computation after i tries
if verbose:
print(f"Good computation after {i} tries")
return

# Bad computation after nb_tries
raise AssertionError(
f"bad computation after {nb_tries} tries, which was supposed to happen with a "
f"probability of {expected_bad_luck}.\nLast engine result:\n{last_engine_result}\n"
f"Last function result:\n{last_function_result}"
)


def subtest_compile_and_run_unary_ufunc_correctness(
ufunc,
upper_function,
input_ranges,
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function"""

@@ -378,6 +352,7 @@ def subtest_compile_and_run_binary_ufunc_correctness(
input_ranges,
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function"""

@@ -418,7 +393,12 @@ def subtest_compile_and_run_binary_ufunc_correctness(
@pytest.mark.parametrize(
"tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
)
def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tensor_shape):
def test_binary_ufunc_operations(
ufunc,
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
):
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
|
||||
run_multi_tlu_test = False
@@ -436,6 +416,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -444,6 +425,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -453,6 +435,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -461,6 +444,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -470,6 +454,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -479,6 +464,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [numpy.lcm, numpy.left_shift]:
# Need small constants to keep results sufficiently small
@@ -489,6 +475,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -497,6 +484,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -508,6 +496,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -518,6 +507,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [numpy.ldexp]:
# Need small constants to keep results sufficiently small
@@ -528,6 +518,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -537,6 +528,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
else:
# General case
@@ -547,6 +539,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -555,6 +548,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -564,6 +558,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -572,6 +567,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)

# Negative inputs tests on compatible functions
@@ -590,6 +586,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 7), (0, 3)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -598,6 +595,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 7), (0, 3)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)


@@ -607,7 +605,9 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
@pytest.mark.parametrize(
"tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
)
def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor_shape):
def test_unary_ufunc_operations(
ufunc, tensor_shape, default_compilation_configuration, check_is_good_execution
):
"""Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""

if ufunc in [
@@ -621,6 +621,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [
numpy.negative,
@@ -632,6 +633,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [
numpy.arccosh,
@@ -647,6 +649,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [
numpy.cosh,
@@ -666,6 +669,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
else:
# Regular case for univariate functions
@@ -675,6 +679,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)

# Negative inputs tests on compatible functions
@@ -696,6 +701,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 7), (0, 3)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)

@@ -811,7 +817,11 @@ def test_compile_and_run_correctness(
],
)
def test_compile_and_run_correctness__for_prog_with_tlu(
function, input_ranges, list_of_arg_names, default_compilation_configuration
function,
input_ranges,
list_of_arg_names,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function which uses a TLU"""

@@ -1085,6 +1095,7 @@ def test_compile_and_run_tensor_correctness(
test_input,
use_check_good_exec,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function with tensor operators"""
circuit = compile_numpy_function(
@@ -1389,6 +1400,7 @@ def test_compile_and_run_lut_correctness(
input_bits,
list_of_arg_names,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function with LUT"""

@@ -1443,6 +1455,7 @@ def test_compile_and_run_negative_lut_correctness(
table,
bit_width,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness when running a compiled function with LUT using negative values"""

@@ -1459,7 +1472,10 @@ def test_compile_and_run_negative_lut_correctness(
check_is_good_execution(circuit, function, [value + offset])


def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
def test_compile_and_run_multi_lut_correctness(
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function with Multi LUT"""

def function_to_compile(x):
@@ -2000,7 +2016,11 @@ def test_compile_and_run_correctness_with_negative_values(
],
)
def test_compile_and_run_correctness_with_negative_values_and_pbs(
function, input_ranges, list_of_arg_names, default_compilation_configuration
function,
input_ranges,
list_of_arg_names,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function, which has some negative
intermediate values."""

@@ -39,7 +39,11 @@ class FC(nn.Module):
[pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
)
def test_quantized_module_compilation(
input_output_feature, model, seed_torch, default_compilation_configuration
input_output_feature,
model,
seed_torch,
default_compilation_configuration,
check_is_good_execution,
):
"""Test a neural network compilation for FHE inference."""
# Seed torch
@@ -68,25 +72,16 @@ def test_quantized_module_compilation(

# Compile
quantized_model.compile(q_input, default_compilation_configuration)
dequant_predictions = quantized_model.forward_and_dequant(q_input)

nb_tries = 5
# Compare predictions between FHE and QuantizedModule
for _ in range(nb_tries):
homomorphic_predictions = []
for x_q in q_input.qvalues:
homomorphic_predictions.append(
quantized_model.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))
)
homomorphic_predictions = quantized_model.dequantize_output(
numpy.array(homomorphic_predictions, dtype=numpy.float32)
for x_q in q_input.qvalues:
x_q = numpy.expand_dims(x_q, 0)
check_is_good_execution(
fhe_circuit=quantized_model.forward_fhe,
function=quantized_model.forward,
args=[x_q.astype(numpy.uint8)],
postprocess_output_func=lambda x: quantized_model.dequantize_output(
x.astype(numpy.float32)
),
check_function=lambda lhs, rhs: numpy.isclose(lhs, rhs).all(),
verbose=False,
)

homomorphic_predictions = homomorphic_predictions.reshape(dequant_predictions.shape)

# Make sure homomorphic_predictions are the same as dequant_predictions
if numpy.isclose(homomorphic_predictions, dequant_predictions).all():
return

# Bad computation after nb_tries
raise AssertionError(f"bad computation after {nb_tries} tries")

@@ -31,13 +31,19 @@ class FC(nn.Module):

@pytest.mark.parametrize(
"model",
[pytest.param(FC, marks=pytest.mark.xfail)],
[pytest.param(FC)],
)
@pytest.mark.parametrize(
"input_output_feature",
[pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
)
def test_compile_torch(input_output_feature, model, seed_torch, default_compilation_configuration):
def test_compile_torch(
input_output_feature,
model,
seed_torch,
default_compilation_configuration,
check_is_good_execution,
):
"""Test the different model architecture from torch numpy."""

# Seed torch
@@ -46,12 +52,14 @@ def test_compile_torch(input_output_feature, model, seed_torch, default_compilat
n_bits = 2

# Define an input shape (n_examples, n_features)
n_examples = 10
n_examples = 50

# Define the torch model
torch_fc_model = model(input_output_feature)
# Create random input
inputset = [numpy.random.uniform(-1, 1, size=input_output_feature) for _ in range(n_examples)]
inputset = [
numpy.random.uniform(-100, 100, size=input_output_feature) for _ in range(n_examples)
]

# Compile
quantized_numpy_module = compile_torch_model(
@@ -61,19 +69,19 @@ def test_compile_torch(input_output_feature, model, seed_torch, default_compilat
n_bits=n_bits,
)

# Quantize inputs all at once to have meaningful scale and zero point
q_input = QuantizedArray(n_bits, numpy.array(inputset))

# Compare predictions between FHE and QuantizedModule
clear_predictions = []
homomorphic_predictions = []
for numpy_input in inputset:
q_input = QuantizedArray(n_bits, numpy_input)
x_q = q_input.qvalues
clear_predictions.append(quantized_numpy_module.forward(x_q))
homomorphic_predictions.append(
quantized_numpy_module.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))
for x_q in q_input.qvalues:
x_q = numpy.expand_dims(x_q, 0)
check_is_good_execution(
fhe_circuit=quantized_numpy_module.forward_fhe,
function=quantized_numpy_module.forward,
args=[x_q.astype(numpy.uint8)],
postprocess_output_func=lambda x: quantized_numpy_module.dequantize_output(
x.astype(numpy.float32)
),
check_function=lambda lhs, rhs: numpy.isclose(lhs, rhs).all(),
verbose=False,
)

clear_predictions = numpy.array(clear_predictions)
homomorphic_predictions = numpy.array(homomorphic_predictions)

# Make sure homomorphic_predictions are the same as dequant_predictions
assert numpy.array_equal(homomorphic_predictions, clear_predictions)
