mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add fork method to configuration to easily change a small setting
This commit is contained in:
@@ -4,9 +4,9 @@ Export everything that users might need.
|
||||
|
||||
from .compilation import (
|
||||
Circuit,
|
||||
DebugArtifacts,
|
||||
Configuration,
|
||||
Compiler,
|
||||
Configuration,
|
||||
DebugArtifacts,
|
||||
EncryptionStatus,
|
||||
compiler,
|
||||
)
|
||||
|
||||
@@ -82,9 +82,7 @@ class Compiler:
|
||||
for param, status in parameter_encryption_statuses.items()
|
||||
}
|
||||
|
||||
self.configuration = (
|
||||
configuration if configuration is not None else Configuration()
|
||||
)
|
||||
self.configuration = configuration if configuration is not None else Configuration()
|
||||
self.artifacts = artifacts if artifacts is not None else DebugArtifacts()
|
||||
|
||||
self.inputset = []
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
Declaration of `Configuration` class.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from copy import deepcopy
|
||||
from typing import Optional, get_type_hints
|
||||
|
||||
_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None
|
||||
|
||||
@@ -49,3 +50,34 @@ class Configuration:
|
||||
"""
|
||||
|
||||
return _INSECURE_KEY_CACHE_LOCATION
|
||||
|
||||
def fork(self, **kwargs) -> "Configuration":
|
||||
"""
|
||||
Get a new configuration from another one specified changes.
|
||||
|
||||
Args:
|
||||
**kwargs:
|
||||
changes to make
|
||||
|
||||
Returns:
|
||||
Configuration:
|
||||
configuration that is forked from self and updated using kwargs
|
||||
"""
|
||||
|
||||
result = deepcopy(self)
|
||||
|
||||
hints = get_type_hints(Configuration)
|
||||
for name, value in kwargs.items():
|
||||
if name not in hints:
|
||||
raise TypeError(f"Unexpected keyword argument '{name}'")
|
||||
|
||||
hint = hints[name]
|
||||
if not isinstance(value, hint): # type: ignore
|
||||
raise TypeError(
|
||||
f"Unexpected type for keyword argument '{name}' "
|
||||
f"(expected '{hint.__name__}', got '{type(value).__name__}')"
|
||||
)
|
||||
|
||||
setattr(result, name, value)
|
||||
|
||||
return result
|
||||
|
||||
@@ -26,3 +26,47 @@ def test_configuration_bad_init(kwargs, expected_error, expected_message):
|
||||
Configuration(**kwargs)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
|
||||
def test_configuration_fork():
|
||||
"""
|
||||
Test `fork` method of `Configuration` class.
|
||||
"""
|
||||
|
||||
config1 = Configuration(enable_unsafe_features=True, loop_parallelize=False)
|
||||
config2 = config1.fork(enable_unsafe_features=False, loop_parallelize=True)
|
||||
|
||||
assert config1 is not config2
|
||||
|
||||
assert config1.enable_unsafe_features is True
|
||||
assert config1.loop_parallelize is False
|
||||
|
||||
assert config2.enable_unsafe_features is False
|
||||
assert config2.loop_parallelize is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs,expected_error,expected_message",
|
||||
[
|
||||
pytest.param(
|
||||
{"foo": False},
|
||||
TypeError,
|
||||
"Unexpected keyword argument 'foo'",
|
||||
),
|
||||
pytest.param(
|
||||
{"dump_artifacts_on_unexpected_failures": "yes"},
|
||||
TypeError,
|
||||
"Unexpected type for keyword argument 'dump_artifacts_on_unexpected_failures' "
|
||||
"(expected 'bool', got 'str')",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_configuration_bad_fork(kwargs, expected_error, expected_message):
|
||||
"""
|
||||
Test `fork` method of `Configuration` class with bad parameters.
|
||||
"""
|
||||
|
||||
with pytest.raises(expected_error) as excinfo:
|
||||
Configuration().fork(**kwargs)
|
||||
|
||||
assert str(excinfo.value) == expected_message
|
||||
|
||||
Reference in New Issue
Block a user