feat: add fork method to configuration to easily change a small setting

This commit is contained in:
Umut
2022-04-28 10:50:32 +02:00
parent 6fe148e72b
commit cc726154b6
4 changed files with 80 additions and 6 deletions

View File

@@ -4,9 +4,9 @@ Export everything that users might need.
from .compilation import (
Circuit,
DebugArtifacts,
Configuration,
Compiler,
Configuration,
DebugArtifacts,
EncryptionStatus,
compiler,
)

View File

@@ -82,9 +82,7 @@ class Compiler:
for param, status in parameter_encryption_statuses.items()
}
self.configuration = (
configuration if configuration is not None else Configuration()
)
self.configuration = configuration if configuration is not None else Configuration()
self.artifacts = artifacts if artifacts is not None else DebugArtifacts()
self.inputset = []

View File

@@ -2,7 +2,8 @@
Declaration of `Configuration` class.
"""
from typing import Optional
from copy import deepcopy
from typing import Optional, get_type_hints
_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None
@@ -49,3 +50,34 @@ class Configuration:
"""
return _INSECURE_KEY_CACHE_LOCATION
def fork(self, **kwargs) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Args:
**kwargs:
changes to make
Returns:
Configuration:
configuration that is forked from self and updated using kwargs
"""
result = deepcopy(self)
hints = get_type_hints(Configuration)
for name, value in kwargs.items():
if name not in hints:
raise TypeError(f"Unexpected keyword argument '{name}'")
hint = hints[name]
if not isinstance(value, hint): # type: ignore
raise TypeError(
f"Unexpected type for keyword argument '{name}' "
f"(expected '{hint.__name__}', got '{type(value).__name__}')"
)
setattr(result, name, value)
return result

View File

@@ -26,3 +26,47 @@ def test_configuration_bad_init(kwargs, expected_error, expected_message):
Configuration(**kwargs)
assert str(excinfo.value) == expected_message
def test_configuration_fork():
"""
Test `fork` method of `Configuration` class.
"""
config1 = Configuration(enable_unsafe_features=True, loop_parallelize=False)
config2 = config1.fork(enable_unsafe_features=False, loop_parallelize=True)
assert config1 is not config2
assert config1.enable_unsafe_features is True
assert config1.loop_parallelize is False
assert config2.enable_unsafe_features is False
assert config2.loop_parallelize is True
@pytest.mark.parametrize(
"kwargs,expected_error,expected_message",
[
pytest.param(
{"foo": False},
TypeError,
"Unexpected keyword argument 'foo'",
),
pytest.param(
{"dump_artifacts_on_unexpected_failures": "yes"},
TypeError,
"Unexpected type for keyword argument 'dump_artifacts_on_unexpected_failures' "
"(expected 'bool', got 'str')",
),
],
)
def test_configuration_bad_fork(kwargs, expected_error, expected_message):
"""
Test `fork` method of `Configuration` class with bad parameters.
"""
with pytest.raises(expected_error) as excinfo:
Configuration().fork(**kwargs)
assert str(excinfo.value) == expected_message