Files
concrete/tests/numpy/test_np_inputset_helpers.py

97 lines
3.0 KiB
Python

"""Test file for numpy inputset helpers"""
import numpy as np
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types import Float, UnsignedInteger
from concrete.common.data_types.base import BaseDataType
from concrete.common.values import BaseValue, EncryptedScalar, EncryptedTensor
from concrete.numpy.np_inputset_helpers import _generate_random_inputset
def test_generate_random_inputset():
    """Check the structure and value ranges of a randomly generated inputset"""

    # Parameters cover unsigned-integer and float dtypes, each once as a scalar
    # and once as a tensor.
    parameters = {
        "x1": EncryptedScalar(UnsignedInteger(4)),
        "x2": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
        "x3": EncryptedScalar(Float(64)),
        "x4": EncryptedTensor(Float(64), shape=(3, 2)),
    }
    configuration = CompilationConfiguration(random_inputset_samples=15)

    inputset = _generate_random_inputset(parameters, configuration)

    # The number of samples is expected to match `random_inputset_samples`.
    assert isinstance(inputset, list)
    assert len(inputset) == 15

    for sample in inputset:
        # Each sample is a tuple holding one entry per parameter.
        assert isinstance(sample, tuple)
        assert len(sample) == 4

        int_scalar, int_tensor, float_scalar, float_tensor = sample

        # "x1": python int within the 4-bit unsigned range
        assert isinstance(int_scalar, int)
        assert 0 <= int_scalar < 2 ** 4

        # "x2": uint64 array of shape (2, 3) within the 4-bit unsigned range
        assert isinstance(int_tensor, np.ndarray)
        assert int_tensor.dtype == np.uint64
        assert int_tensor.shape == (2, 3)
        assert (int_tensor >= 0).all()
        assert (int_tensor < 2 ** 4).all()

        # "x3": python float within [0, 1)
        assert isinstance(float_scalar, float)
        assert 0 <= float_scalar < 1

        # "x4": float64 array of shape (3, 2) within [0, 1)
        assert isinstance(float_tensor, np.ndarray)
        assert float_tensor.dtype == np.float64
        assert float_tensor.shape == (3, 2)
        assert (float_tensor >= 0).all()
        assert (float_tensor < 1).all()
def test_fail_generate_random_inputset():
    """Test function for failed generate_random_inputset

    Two unsupported inputs are checked: a value class unknown to the helper, and a
    known value class wrapping an unknown dtype. Each call is expected to raise a
    ValueError with an exact error message.
    """

    class MockDtype(BaseDataType):
        """Unsupported dtype to check error messages"""

        def __eq__(self, o: object) -> bool:
            return False

        def __str__(self):
            return "MockDtype"

    class MockValue(BaseValue):
        """Unsupported value to check error messages"""

        def __init__(self):
            super().__init__(MockDtype(), is_encrypted=True)

        def __eq__(self, other: object) -> bool:
            return False

        def __str__(self):
            return "MockValue"

    # Only ValueError is captured here; any other exception type propagates with
    # its own traceback instead of being hidden behind a message comparison.
    with pytest.raises(ValueError) as excinfo:
        _generate_random_inputset(
            {"x": MockValue()},
            CompilationConfiguration(random_inputset_samples=15),
        )

    expected = "Random inputset cannot be generated for MockValue parameters"
    assert str(excinfo.value) == expected

    with pytest.raises(ValueError) as excinfo:
        _generate_random_inputset(
            {"x": EncryptedScalar(MockDtype())},
            CompilationConfiguration(random_inputset_samples=15),
        )

    expected = "Random inputset cannot be generated for parameters of type MockDtype"
    assert str(excinfo.value) == expected