chore: make check_is_good_execution a fixture and fix flaky tests using it

closes #1061
Arthur Meyre
2021-12-03 12:22:30 +01:00
parent 2f1e41e4fb
commit a0c26315ea
4 changed files with 137 additions and 72 deletions

View File

@@ -4,14 +4,16 @@ import operator
import random
import re
from pathlib import Path
from typing import Callable, Dict, Type
from typing import Any, Callable, Dict, Iterable, Type
import networkx as nx
import networkx.algorithms.isomorphism as iso
import numpy
import pytest
import torch
from concrete.common.compilation import CompilationConfiguration
from concrete.common.fhe_circuit import FHECircuit
from concrete.common.representation.intermediate import (
ALL_IR_NODES,
Add,
@@ -293,3 +295,43 @@ def seed_torch():
"""Fixture to seed torch"""
return function_to_seed_torch
def check_is_good_execution_impl(
fhe_circuit: FHECircuit,
function: Callable,
args: Iterable[Any],
preprocess_input_func: Callable[[Any], Any] = lambda x: x,
postprocess_output_func: Callable[[Any], Any] = lambda x: x,
check_function: Callable[[Any, Any], bool] = numpy.array_equal,
verbose: bool = True,
):
"""Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
return an error. One can set the expected probability of success of one execution and the
number of tests, to finetune the probability of bad luck, ie that we run several times the
check and always have a wrong result."""
nb_tries = 5
for i in range(1, nb_tries + 1):
preprocessed_args = tuple(preprocess_input_func(val) for val in args)
if check_function(
last_engine_result := postprocess_output_func(fhe_circuit.run(*preprocessed_args)),
last_function_result := postprocess_output_func(function(*preprocessed_args)),
):
# Good computation after i tries
if verbose:
print(f"Good computation after {i} tries")
return
# Bad computation after nb_tries
raise AssertionError(
f"bad computation after {nb_tries} tries.\nLast engine result:\n{last_engine_result}\n"
f"Last function result:\n{last_function_result}"
)
@pytest.fixture
def check_is_good_execution():
"""Fixture to seed torch"""
return check_is_good_execution_impl
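For context (not part of the diff), a minimal sketch of how a test would consume the new fixture. FakeCircuit, clear_function and test_addition below are hypothetical stand-ins; only the fixture itself comes from this commit. The retry logic matters because a single FHE execution can occasionally be wrong: with 5 attempts and, say, a 0.95 per-attempt success probability, the old helper estimated the chance that every attempt fails at (1 - 0.95) ** 5, i.e. about 3e-7.

import numpy


class FakeCircuit:
    """Hypothetical stand-in for a compiled FHECircuit: run() mimics an FHE execution
    that is wrong roughly 5% of the time."""

    def run(self, x, y):
        noise = 1 if numpy.random.random() < 0.05 else 0
        return x + y + noise


def clear_function(x, y):
    return x + y


def test_addition(check_is_good_execution):
    # The fixture retries internally (5 times) and raises AssertionError only if
    # every attempt disagrees with clear_function, so one noisy run does not fail the test.
    check_is_good_execution(FakeCircuit(), clear_function, [3, 4])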

View File

@@ -305,39 +305,13 @@ def negative_unary_f(func, x, y):
return z
def check_is_good_execution(compiler_engine, function, args, verbose=True):
"""Run several times the check compiler_engine.run(*args) == function(*args). If always wrong,
return an error. One can set the expected probability of success of one execution and the
number of tests, to finetune the probability of bad luck, ie that we run several times the
check and always have a wrong result."""
expected_probability_of_success = 0.95
nb_tries = 5
expected_bad_luck = (1 - expected_probability_of_success) ** nb_tries
for i in range(1, nb_tries + 1):
if numpy.array_equal(
last_engine_result := compiler_engine.run(*args),
last_function_result := function(*args),
):
# Good computation after i tries
if verbose:
print(f"Good computation after {i} tries")
return
# Bad computation after nb_tries
raise AssertionError(
f"bad computation after {nb_tries} tries, which was supposed to happen with a "
f"probability of {expected_bad_luck}.\nLast engine result:\n{last_engine_result}\n"
f"Last function result:\n{last_function_result}"
)
def subtest_compile_and_run_unary_ufunc_correctness(
ufunc,
upper_function,
input_ranges,
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function"""
@@ -378,6 +352,7 @@ def subtest_compile_and_run_binary_ufunc_correctness(
input_ranges,
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function"""
@@ -418,7 +393,12 @@ def subtest_compile_and_run_binary_ufunc_correctness(
@pytest.mark.parametrize(
"tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
)
def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tensor_shape):
def test_binary_ufunc_operations(
ufunc,
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
):
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
run_multi_tlu_test = False
@@ -436,6 +416,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -444,6 +425,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -453,6 +435,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -461,6 +444,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -470,6 +454,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -479,6 +464,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [numpy.lcm, numpy.left_shift]:
# Need small constants to keep results sufficiently small
@@ -489,6 +475,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -497,6 +484,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -508,6 +496,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -518,6 +507,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [numpy.ldexp]:
# Need small constants to keep results sufficiently small
@@ -528,6 +518,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -537,6 +528,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
else:
# General case
@@ -547,6 +539,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -555,6 +548,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
@@ -564,6 +558,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -572,6 +567,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
# Negative inputs tests on compatible functions
@@ -590,6 +586,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 7), (0, 3)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -598,6 +595,7 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
((0, 7), (0, 3)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
@@ -607,7 +605,9 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
@pytest.mark.parametrize(
"tensor_shape", [pytest.param((), id="scalar"), pytest.param((3, 1, 2), id="tensor")]
)
def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor_shape):
def test_unary_ufunc_operations(
ufunc, tensor_shape, default_compilation_configuration, check_is_good_execution
):
"""Test unary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
if ufunc in [
@@ -621,6 +621,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [
numpy.negative,
@@ -632,6 +633,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [
numpy.arccosh,
@@ -647,6 +649,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
elif ufunc in [
numpy.cosh,
@@ -666,6 +669,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
else:
# Regular case for univariate functions
@@ -675,6 +679,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
# Negative inputs tests on compatible functions
@@ -696,6 +701,7 @@ def test_unary_ufunc_operations(ufunc, default_compilation_configuration, tensor
((0, 7), (0, 3)),
tensor_shape,
default_compilation_configuration,
check_is_good_execution,
)
@@ -811,7 +817,11 @@ def test_compile_and_run_correctness(
],
)
def test_compile_and_run_correctness__for_prog_with_tlu(
function, input_ranges, list_of_arg_names, default_compilation_configuration
function,
input_ranges,
list_of_arg_names,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function which uses a TLU"""
@@ -1085,6 +1095,7 @@ def test_compile_and_run_tensor_correctness(
test_input,
use_check_good_exec,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function with tensor operators"""
circuit = compile_numpy_function(
@@ -1389,6 +1400,7 @@ def test_compile_and_run_lut_correctness(
input_bits,
list_of_arg_names,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function with LUT"""
@@ -1443,6 +1455,7 @@ def test_compile_and_run_negative_lut_correctness(
table,
bit_width,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness when running a compiled function with LUT using negative values"""
@@ -1459,7 +1472,10 @@ def test_compile_and_run_negative_lut_correctness(
check_is_good_execution(circuit, function, [value + offset])
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
def test_compile_and_run_multi_lut_correctness(
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function with Multi LUT"""
def function_to_compile(x):
@@ -2000,7 +2016,11 @@ def test_compile_and_run_correctness_with_negative_values(
],
)
def test_compile_and_run_correctness_with_negative_values_and_pbs(
function, input_ranges, list_of_arg_names, default_compilation_configuration
function,
input_ranges,
list_of_arg_names,
default_compilation_configuration,
check_is_good_execution,
):
"""Test correctness of results when running a compiled function, which has some negative
intermediate values."""

View File

@@ -39,7 +39,11 @@ class FC(nn.Module):
[pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
)
def test_quantized_module_compilation(
input_output_feature, model, seed_torch, default_compilation_configuration
input_output_feature,
model,
seed_torch,
default_compilation_configuration,
check_is_good_execution,
):
"""Test a neural network compilation for FHE inference."""
# Seed torch
@@ -68,25 +72,16 @@ def test_quantized_module_compilation(
# Compile
quantized_model.compile(q_input, default_compilation_configuration)
dequant_predictions = quantized_model.forward_and_dequant(q_input)
nb_tries = 5
# Compare predictions between FHE and QuantizedModule
for _ in range(nb_tries):
homomorphic_predictions = []
for x_q in q_input.qvalues:
homomorphic_predictions.append(
quantized_model.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))
)
homomorphic_predictions = quantized_model.dequantize_output(
numpy.array(homomorphic_predictions, dtype=numpy.float32)
for x_q in q_input.qvalues:
x_q = numpy.expand_dims(x_q, 0)
check_is_good_execution(
fhe_circuit=quantized_model.forward_fhe,
function=quantized_model.forward,
args=[x_q.astype(numpy.uint8)],
postprocess_output_func=lambda x: quantized_model.dequantize_output(
x.astype(numpy.float32)
),
check_function=lambda lhs, rhs: numpy.isclose(lhs, rhs).all(),
verbose=False,
)
homomorphic_predictions = homomorphic_predictions.reshape(dequant_predictions.shape)
# Make sure homomorphic_predictions are the same as dequant_predictions
if numpy.isclose(homomorphic_predictions, dequant_predictions).all():
return
# Bad computation after nb_tries
raise AssertionError(f"bad computation after {nb_tries} tries")

View File

@@ -31,13 +31,19 @@ class FC(nn.Module):
@pytest.mark.parametrize(
"model",
[pytest.param(FC, marks=pytest.mark.xfail)],
[pytest.param(FC)],
)
@pytest.mark.parametrize(
"input_output_feature",
[pytest.param(input_output_feature) for input_output_feature in INPUT_OUTPUT_FEATURE],
)
def test_compile_torch(input_output_feature, model, seed_torch, default_compilation_configuration):
def test_compile_torch(
input_output_feature,
model,
seed_torch,
default_compilation_configuration,
check_is_good_execution,
):
"""Test the different model architecture from torch numpy."""
# Seed torch
@@ -46,12 +52,14 @@ def test_compile_torch(input_output_feature, model, seed_torch, default_compilat
n_bits = 2
# Define an input shape (n_examples, n_features)
n_examples = 10
n_examples = 50
# Define the torch model
torch_fc_model = model(input_output_feature)
# Create random input
inputset = [numpy.random.uniform(-1, 1, size=input_output_feature) for _ in range(n_examples)]
inputset = [
numpy.random.uniform(-100, 100, size=input_output_feature) for _ in range(n_examples)
]
# Compile
quantized_numpy_module = compile_torch_model(
@@ -61,19 +69,19 @@ def test_compile_torch(input_output_feature, model, seed_torch, default_compilat
n_bits=n_bits,
)
# Quantize inputs all at once to have meaningful scale and zero point
q_input = QuantizedArray(n_bits, numpy.array(inputset))
# Compare predictions between FHE and QuantizedModule
clear_predictions = []
homomorphic_predictions = []
for numpy_input in inputset:
q_input = QuantizedArray(n_bits, numpy_input)
x_q = q_input.qvalues
clear_predictions.append(quantized_numpy_module.forward(x_q))
homomorphic_predictions.append(
quantized_numpy_module.forward_fhe.run(numpy.array([x_q]).astype(numpy.uint8))
for x_q in q_input.qvalues:
x_q = numpy.expand_dims(x_q, 0)
check_is_good_execution(
fhe_circuit=quantized_numpy_module.forward_fhe,
function=quantized_numpy_module.forward,
args=[x_q.astype(numpy.uint8)],
postprocess_output_func=lambda x: quantized_numpy_module.dequantize_output(
x.astype(numpy.float32)
),
check_function=lambda lhs, rhs: numpy.isclose(lhs, rhs).all(),
verbose=False,
)
clear_predictions = numpy.array(clear_predictions)
homomorphic_predictions = numpy.array(homomorphic_predictions)
# Make sure homomorphic_predictions are the same as dequant_predictions
assert numpy.array_equal(homomorphic_predictions, clear_predictions)