diff --git a/tests/conftest.py b/tests/conftest.py
index e3216e84d..b9d7ae27d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
 """PyTest configuration file"""
 import json
 import operator
+import random
 import re
 from pathlib import Path
 from typing import Callable, Dict, Type
@@ -8,6 +9,7 @@ from typing import Callable, Dict, Type
 import networkx as nx
 import networkx.algorithms.isomorphism as iso
 import pytest
+import torch
 
 from concrete.common.compilation import CompilationConfiguration
 from concrete.common.representation.intermediate import (
@@ -276,3 +278,18 @@ REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
 def remove_color_codes():
     """Return the re object to remove color codes"""
     return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)
+
+
+def function_to_seed_torch():
+    """Function to seed torch"""
+
+    # Seed torch with something which is seeded by pytest-randomly
+    torch.manual_seed(random.randint(0, 2 ** 64 - 1))
+    torch.use_deterministic_algorithms(True)
+
+
+@pytest.fixture
+def seed_torch():
+    """Fixture to seed torch"""
+
+    return function_to_seed_torch
diff --git a/tests/quantization/test_quantized_module.py b/tests/quantization/test_quantized_module.py
index fb12bebbb..0e77357f8 100644
--- a/tests/quantization/test_quantized_module.py
+++ b/tests/quantization/test_quantized_module.py
@@ -80,12 +80,14 @@ N_BITS_ATOL_TUPLE_LIST = [
         pytest.param(FC, (100, 32 * 32 * 3)),
     ],
 )
-def test_quantized_linear(model, input_shape, n_bits, atol):
+def test_quantized_linear(model, input_shape, n_bits, atol, seed_torch):
     """Test the quantized module with a post-training static quantization.
 
     With n_bits>>0 we expect the results of the quantized module to be the same as the standard
     module.
     """
+    # Seed torch
+    seed_torch()
     # Define the torch model
     torch_fc_model = model()
     # Create random input
diff --git a/tests/torch/test_torch_to_numpy.py b/tests/torch/test_torch_to_numpy.py
index b194f3648..127d6bf01 100644
--- a/tests/torch/test_torch_to_numpy.py
+++ b/tests/torch/test_torch_to_numpy.py
@@ -66,9 +66,11 @@ class FC(nn.Module):
         pytest.param(FC, (100, 32 * 32 * 3)),
     ],
 )
-def test_torch_to_numpy(model, input_shape):
+def test_torch_to_numpy(model, input_shape, seed_torch):
     """Test the different model architecture from torch numpy."""
 
+    # Seed torch
+    seed_torch()
     # Define the torch model
     torch_fc_model = model()
     # Create random input
@@ -104,9 +106,10 @@
     "model, incompatible_layer",
     [pytest.param(CNN, "Conv2d")],
 )
-def test_raises(model, incompatible_layer):
+def test_raises(model, incompatible_layer, seed_torch):
     """Function to test incompatible layers."""
 
+    seed_torch()
     torch_incompatible_model = model()
     expected_errmsg = (
         f"The following module is currently not implemented: {incompatible_layer}. "
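
Note on the seeding scheme: pytest-randomly reseeds Python's random module before each test, so a torch seed drawn with random.randint(0, 2 ** 64 - 1) is fully determined by the pytest-randomly seed (torch.manual_seed accepts any value in that range). A minimal standalone sketch of the mechanism, outside the patch; the fixed value 42 is a stand-in for whatever seed pytest-randomly picks:

    import random
    import torch

    # Emulate two runs of the same test under the same pytest-randomly seed;
    # pytest-randomly effectively reseeds the random module before each test.
    for _ in range(2):
        random.seed(42)  # stand-in for pytest-randomly's per-test reseeding
        torch.manual_seed(random.randint(0, 2 ** 64 - 1))  # derived torch seed is identical
        print(torch.rand(3))  # prints the same tensor on both iterations

Because the torch seed is derived from the random module rather than hardcoded, rerunning a failing test with pytest's --randomly-seed=<reported seed> reproduces the same torch weights and inputs.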