chore: seed torch as much as possible

closes #877
This commit is contained in:
Benoit Chevallier-Mames
2021-11-23 17:54:29 +01:00
committed by Benoit Chevallier
parent 60d8079303
commit ec396effb2
3 changed files with 25 additions and 3 deletions

View File

@@ -1,6 +1,7 @@
"""PyTest configuration file"""
import json
import operator
import random
import re
from pathlib import Path
from typing import Callable, Dict, Type
@@ -8,6 +9,7 @@ from typing import Callable, Dict, Type
import networkx as nx
import networkx.algorithms.isomorphism as iso
import pytest
import torch
from concrete.common.compilation import CompilationConfiguration
from concrete.common.representation.intermediate import (
@@ -276,3 +278,18 @@ REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
def remove_color_codes():
"""Return the re object to remove color codes"""
return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)
def function_to_seed_torch():
"""Function to seed torch"""
# Seed torch with something which is seed by pytest-randomly
torch.manual_seed(random.randint(0, 2 ** 64 - 1))
torch.use_deterministic_algorithms(True)
@pytest.fixture
def seed_torch():
"""Fixture to seed torch"""
return function_to_seed_torch

View File

@@ -80,12 +80,14 @@ N_BITS_ATOL_TUPLE_LIST = [
pytest.param(FC, (100, 32 * 32 * 3)),
],
)
def test_quantized_linear(model, input_shape, n_bits, atol):
def test_quantized_linear(model, input_shape, n_bits, atol, seed_torch):
"""Test the quantized module with a post-training static quantization.
With n_bits>>0 we expect the results of the quantized module
to be the same as the standard module.
"""
# Seed torch
seed_torch()
# Define the torch model
torch_fc_model = model()
# Create random input

View File

@@ -66,9 +66,11 @@ class FC(nn.Module):
pytest.param(FC, (100, 32 * 32 * 3)),
],
)
def test_torch_to_numpy(model, input_shape):
def test_torch_to_numpy(model, input_shape, seed_torch):
"""Test the different model architecture from torch numpy."""
# Seed torch
seed_torch()
# Define the torch model
torch_fc_model = model()
# Create random input
@@ -104,9 +106,10 @@ def test_torch_to_numpy(model, input_shape):
"model, incompatible_layer",
[pytest.param(CNN, "Conv2d")],
)
def test_raises(model, incompatible_layer):
def test_raises(model, incompatible_layer, seed_torch):
"""Function to test incompatible layers."""
seed_torch()
torch_incompatible_model = model()
expected_errmsg = (
f"The following module is currently not implemented: {incompatible_layer}. "