feat(compilation): provide a way to automatically generate a random inputset

This commit is contained in:
Umut
2021-10-22 16:17:15 +03:00
parent 9459675cfb
commit 70fbac7188
5 changed files with 341 additions and 10 deletions

View File

@@ -1,12 +1,13 @@
"""Test file for numpy compilation functions"""
import itertools
import random
from copy import deepcopy
import numpy
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.debugging import draw_graph, get_printable_graph
from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
@@ -1131,3 +1132,59 @@ def test_failure_for_signed_output(default_compilation_configuration):
"return(%2)\n"
)
# pylint: enable=line-too-long
def test_compile_with_random_inputset(default_compilation_configuration):
"""Test function for compile with random input set"""
configuration_to_use = deepcopy(default_compilation_configuration)
configuration_to_use.enable_unsafe_features = True
compile_numpy_function_into_op_graph(
lambda x: x + 1,
{"x": EncryptedScalar(UnsignedInteger(6))},
inputset="random",
compilation_configuration=configuration_to_use,
)
compile_numpy_function(
lambda x: x + 32,
{"x": EncryptedScalar(UnsignedInteger(6))},
inputset="random",
compilation_configuration=configuration_to_use,
)
def test_fail_compile_with_random_inputset(default_compilation_configuration):
"""Test function for failed compile with random input set"""
with pytest.raises(ValueError):
try:
compile_numpy_function_into_op_graph(
lambda x: x + 1,
{"x": EncryptedScalar(UnsignedInteger(3))},
inputset="unsupported",
compilation_configuration=default_compilation_configuration,
)
except Exception as error:
expected = (
"inputset can only be an iterable of tuples or the string 'random' "
"but you specified 'unsupported' for it"
)
assert str(error) == expected
raise
with pytest.raises(RuntimeError):
try:
compile_numpy_function(
lambda x: x + 1,
{"x": EncryptedScalar(UnsignedInteger(3))},
inputset="random",
compilation_configuration=default_compilation_configuration,
)
except Exception as error:
expected = (
"Random inputset generation is an unsafe feature "
"and should not be used if you don't know what you are doing"
)
assert str(error) == expected
raise

View File

@@ -0,0 +1,96 @@
"""Test file for numpy inputset helpers"""
import numpy as np
import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types import Float, UnsignedInteger
from concrete.common.data_types.base import BaseDataType
from concrete.common.values import BaseValue, EncryptedScalar, EncryptedTensor
from concrete.numpy.np_inputset_helpers import _generate_random_inputset
def test_generate_random_inputset():
"""Test function for generate_random_inputset"""
inputset = _generate_random_inputset(
{
"x1": EncryptedScalar(UnsignedInteger(4)),
"x2": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
"x3": EncryptedScalar(Float(64)),
"x4": EncryptedTensor(Float(64), shape=(3, 2)),
},
CompilationConfiguration(random_inputset_samples=15),
)
assert isinstance(inputset, list)
assert len(inputset) == 15
for sample in inputset:
assert isinstance(sample, tuple)
assert len(sample) == 4
assert isinstance(sample[0], int)
assert 0 <= sample[0] < 2 ** 4
assert isinstance(sample[1], np.ndarray)
assert sample[1].dtype == np.uint64
assert sample[1].shape == (2, 3)
assert (sample[1] >= 0).all()
assert (sample[1] < 2 ** 4).all()
assert isinstance(sample[2], float)
assert 0 <= sample[2] < 1
assert isinstance(sample[3], np.ndarray)
assert sample[3].dtype == np.float64
assert sample[3].shape == (3, 2)
assert (sample[3] >= 0).all()
assert (sample[3] < 1).all()
def test_fail_generate_random_inputset():
"""Test function for failed generate_random_inputset"""
class MockDtype(BaseDataType):
"""Unsupported dtype to check error messages"""
def __eq__(self, o: object) -> bool:
return False
def __str__(self):
return "MockDtype"
class MockValue(BaseValue):
"""Unsupported value to check error messages"""
def __init__(self):
super().__init__(MockDtype(), is_encrypted=True)
def __eq__(self, other: object) -> bool:
return False
def __str__(self):
return "MockValue"
with pytest.raises(ValueError):
try:
_generate_random_inputset(
{"x": MockValue()},
CompilationConfiguration(random_inputset_samples=15),
)
except Exception as error:
expected = "Random inputset cannot be generated for MockValue parameters"
assert str(error) == expected
raise
with pytest.raises(ValueError):
try:
_generate_random_inputset(
{"x": EncryptedScalar(MockDtype())},
CompilationConfiguration(random_inputset_samples=15),
)
except Exception as error:
expected = "Random inputset cannot be generated for parameters of type MockDtype"
assert str(error) == expected
raise