mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compilation): provide a way to automatically generate a random inputset
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""Test file for numpy compilation functions"""
|
||||
import itertools
|
||||
import random
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
@@ -1131,3 +1132,59 @@ def test_failure_for_signed_output(default_compilation_configuration):
|
||||
"return(%2)\n"
|
||||
)
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
def test_compile_with_random_inputset(default_compilation_configuration):
|
||||
"""Test function for compile with random input set"""
|
||||
|
||||
configuration_to_use = deepcopy(default_compilation_configuration)
|
||||
configuration_to_use.enable_unsafe_features = True
|
||||
|
||||
compile_numpy_function_into_op_graph(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedScalar(UnsignedInteger(6))},
|
||||
inputset="random",
|
||||
compilation_configuration=configuration_to_use,
|
||||
)
|
||||
compile_numpy_function(
|
||||
lambda x: x + 32,
|
||||
{"x": EncryptedScalar(UnsignedInteger(6))},
|
||||
inputset="random",
|
||||
compilation_configuration=configuration_to_use,
|
||||
)
|
||||
|
||||
|
||||
def test_fail_compile_with_random_inputset(default_compilation_configuration):
|
||||
"""Test function for failed compile with random input set"""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
compile_numpy_function_into_op_graph(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedScalar(UnsignedInteger(3))},
|
||||
inputset="unsupported",
|
||||
compilation_configuration=default_compilation_configuration,
|
||||
)
|
||||
except Exception as error:
|
||||
expected = (
|
||||
"inputset can only be an iterable of tuples or the string 'random' "
|
||||
"but you specified 'unsupported' for it"
|
||||
)
|
||||
assert str(error) == expected
|
||||
raise
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
try:
|
||||
compile_numpy_function(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedScalar(UnsignedInteger(3))},
|
||||
inputset="random",
|
||||
compilation_configuration=default_compilation_configuration,
|
||||
)
|
||||
except Exception as error:
|
||||
expected = (
|
||||
"Random inputset generation is an unsafe feature "
|
||||
"and should not be used if you don't know what you are doing"
|
||||
)
|
||||
assert str(error) == expected
|
||||
raise
|
||||
|
||||
96
tests/numpy/test_np_inputset_helpers.py
Normal file
96
tests/numpy/test_np_inputset_helpers.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Test file for numpy inputset helpers"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types import Float, UnsignedInteger
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
from concrete.common.values import BaseValue, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy.np_inputset_helpers import _generate_random_inputset
|
||||
|
||||
|
||||
def test_generate_random_inputset():
|
||||
"""Test function for generate_random_inputset"""
|
||||
|
||||
inputset = _generate_random_inputset(
|
||||
{
|
||||
"x1": EncryptedScalar(UnsignedInteger(4)),
|
||||
"x2": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
|
||||
"x3": EncryptedScalar(Float(64)),
|
||||
"x4": EncryptedTensor(Float(64), shape=(3, 2)),
|
||||
},
|
||||
CompilationConfiguration(random_inputset_samples=15),
|
||||
)
|
||||
|
||||
assert isinstance(inputset, list)
|
||||
assert len(inputset) == 15
|
||||
|
||||
for sample in inputset:
|
||||
assert isinstance(sample, tuple)
|
||||
assert len(sample) == 4
|
||||
|
||||
assert isinstance(sample[0], int)
|
||||
assert 0 <= sample[0] < 2 ** 4
|
||||
|
||||
assert isinstance(sample[1], np.ndarray)
|
||||
assert sample[1].dtype == np.uint64
|
||||
assert sample[1].shape == (2, 3)
|
||||
assert (sample[1] >= 0).all()
|
||||
assert (sample[1] < 2 ** 4).all()
|
||||
|
||||
assert isinstance(sample[2], float)
|
||||
assert 0 <= sample[2] < 1
|
||||
|
||||
assert isinstance(sample[3], np.ndarray)
|
||||
assert sample[3].dtype == np.float64
|
||||
assert sample[3].shape == (3, 2)
|
||||
assert (sample[3] >= 0).all()
|
||||
assert (sample[3] < 1).all()
|
||||
|
||||
|
||||
def test_fail_generate_random_inputset():
|
||||
"""Test function for failed generate_random_inputset"""
|
||||
|
||||
class MockDtype(BaseDataType):
|
||||
"""Unsupported dtype to check error messages"""
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return "MockDtype"
|
||||
|
||||
class MockValue(BaseValue):
|
||||
"""Unsupported value to check error messages"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(MockDtype(), is_encrypted=True)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return "MockValue"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
_generate_random_inputset(
|
||||
{"x": MockValue()},
|
||||
CompilationConfiguration(random_inputset_samples=15),
|
||||
)
|
||||
except Exception as error:
|
||||
expected = "Random inputset cannot be generated for MockValue parameters"
|
||||
assert str(error) == expected
|
||||
raise
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
_generate_random_inputset(
|
||||
{"x": EncryptedScalar(MockDtype())},
|
||||
CompilationConfiguration(random_inputset_samples=15),
|
||||
)
|
||||
except Exception as error:
|
||||
expected = "Random inputset cannot be generated for parameters of type MockDtype"
|
||||
assert str(error) == expected
|
||||
raise
|
||||
Reference in New Issue
Block a user