mirror of https://github.com/zama-ai/concrete.git
committed by Benoit Chevallier
parent 60d8079303
commit ec396effb2
@@ -1,6 +1,7 @@
 """PyTest configuration file"""
 import json
 import operator
+import random
 import re
 from pathlib import Path
 from typing import Callable, Dict, Type
@@ -8,6 +9,7 @@ from typing import Callable, Dict, Type
 import networkx as nx
 import networkx.algorithms.isomorphism as iso
 import pytest
+import torch
 
 from concrete.common.compilation import CompilationConfiguration
 from concrete.common.representation.intermediate import (
@@ -276,3 +278,18 @@ REMOVE_COLOR_CODES_RE = re.compile(r"\x1b[^m]*m")
 def remove_color_codes():
     """Return the re object to remove color codes"""
     return lambda x: REMOVE_COLOR_CODES_RE.sub("", x)
+
+
+def function_to_seed_torch():
+    """Function to seed torch"""
+
+    # Seed torch with something which is seeded by pytest-randomly
+    torch.manual_seed(random.randint(0, 2 ** 64 - 1))
+    torch.use_deterministic_algorithms(True)
+
+
+@pytest.fixture
+def seed_torch():
+    """Fixture to seed torch"""
+
+    return function_to_seed_torch
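The new fixture leans on pytest-randomly: that plugin reseeds Python's `random` module at the start of every test, so `random.randint(0, 2 ** 64 - 1)` yields the same torch seed on a re-run with the same pytest seed. A minimal sketch of that interaction, outside this commit (the seed value 12345 is an arbitrary stand-in for pytest-randomly):

# Sketch, not part of this commit: seeding torch through Python's `random`
# module makes tensor draws reproducible, because pytest-randomly reseeds
# `random` identically at the start of each test run.
import random

import torch

random.seed(12345)  # stand-in for pytest-randomly seeding the test
torch.manual_seed(random.randint(0, 2 ** 64 - 1))
first = torch.randn(3)

random.seed(12345)  # same pytest seed on a re-run
torch.manual_seed(random.randint(0, 2 ** 64 - 1))
second = torch.randn(3)

assert torch.equal(first, second)  # identical draws across "runs"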
@@ -80,12 +80,14 @@ N_BITS_ATOL_TUPLE_LIST = [
         pytest.param(FC, (100, 32 * 32 * 3)),
     ],
 )
-def test_quantized_linear(model, input_shape, n_bits, atol):
+def test_quantized_linear(model, input_shape, n_bits, atol, seed_torch):
     """Test the quantized module with a post-training static quantization.
 
     With n_bits>>0 we expect the results of the quantized module
     to be the same as the standard module.
     """
+    # Seed torch
+    seed_torch()
     # Define the torch model
     torch_fc_model = model()
     # Create random input
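For reference, the usage pattern these hunks introduce, as a self-contained sketch: the test requests `seed_torch` and calls it before constructing the model, so weight initialization and random inputs are reproducible. `TinyFC` and `test_shapes` below are hypothetical stand-ins, not names from this commit; the sketch assumes the `seed_torch` fixture above is on the conftest path:

# Hedged sketch of the fixture usage pattern; TinyFC is a hypothetical
# stand-in for the FC model in the diff, and seed_torch is assumed to be
# provided by the conftest.py shown above.
import pytest
import torch
from torch import nn


class TinyFC(nn.Module):
    """Hypothetical single-layer stand-in for the FC model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)


@pytest.mark.parametrize("model, input_shape", [pytest.param(TinyFC, (100, 8))])
def test_shapes(model, input_shape, seed_torch):
    seed_torch()  # deterministic weights and inputs under pytest-randomly
    net = model()
    x = torch.randn(input_shape)
    assert net(x).shape == (input_shape[0], 4)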
@@ -66,9 +66,11 @@ class FC(nn.Module):
         pytest.param(FC, (100, 32 * 32 * 3)),
     ],
 )
-def test_torch_to_numpy(model, input_shape):
+def test_torch_to_numpy(model, input_shape, seed_torch):
     """Test the different model architectures from torch to numpy."""
 
+    # Seed torch
+    seed_torch()
     # Define the torch model
     torch_fc_model = model()
     # Create random input
@@ -104,9 +106,10 @@ def test_torch_to_numpy(model, input_shape):
     "model, incompatible_layer",
     [pytest.param(CNN, "Conv2d")],
 )
-def test_raises(model, incompatible_layer):
+def test_raises(model, incompatible_layer, seed_torch):
     """Function to test incompatible layers."""
 
+    seed_torch()
     torch_incompatible_model = model()
     expected_errmsg = (
         f"The following module is currently not implemented: {incompatible_layer}. "
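The truncated `expected_errmsg` above is presumably fed to `pytest.raises(..., match=...)`, which treats its argument as a regex searched against the exception message, so literal messages are usually passed through `re.escape` first. A hedged sketch of that pattern with an illustrative converter and error type, neither taken from this diff:

# Sketch of consuming expected_errmsg; convert() and ValueError are
# illustrative assumptions, not the converter or error type in this repo.
import re

import pytest


def convert(module):  # hypothetical stand-in for the torch-to-numpy converter
    raise ValueError(
        f"The following module is currently not implemented: {type(module).__name__}. "
    )


def test_raises_sketch():
    expected_errmsg = "The following module is currently not implemented: int. "
    # match= is a regex, so escape the literal message before matching
    with pytest.raises(ValueError, match=re.escape(expected_errmsg)):
        convert(0)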