From cc726154b648dbe35d30db52c63111dc47335a7c Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 28 Apr 2022 10:50:32 +0200 Subject: [PATCH] feat: add fork method to configuration to easily change a small setting --- concrete/numpy/__init__.py | 4 +- concrete/numpy/compilation/compiler.py | 4 +- concrete/numpy/compilation/configuration.py | 34 +++++++++++++++- tests/compilation/test_configuration.py | 44 +++++++++++++++++++++ 4 files changed, 80 insertions(+), 6 deletions(-) diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 4541e5696..3f8d73b3d 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -4,9 +4,9 @@ Export everything that users might need. from .compilation import ( Circuit, - DebugArtifacts, - Configuration, Compiler, + Configuration, + DebugArtifacts, EncryptionStatus, compiler, ) diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index c06ad72bc..5b78432c4 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -82,9 +82,7 @@ class Compiler: for param, status in parameter_encryption_statuses.items() } - self.configuration = ( - configuration if configuration is not None else Configuration() - ) + self.configuration = configuration if configuration is not None else Configuration() self.artifacts = artifacts if artifacts is not None else DebugArtifacts() self.inputset = [] diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 355d2f622..fdd87d55c 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -2,7 +2,8 @@ Declaration of `Configuration` class. """ -from typing import Optional +from copy import deepcopy +from typing import Optional, get_type_hints _INSECURE_KEY_CACHE_LOCATION: Optional[str] = None @@ -49,3 +50,34 @@ class Configuration: """ return _INSECURE_KEY_CACHE_LOCATION + + def fork(self, **kwargs) -> "Configuration": + """ + Get a new configuration from another one specified changes. + + Args: + **kwargs: + changes to make + + Returns: + Configuration: + configuration that is forked from self and updated using kwargs + """ + + result = deepcopy(self) + + hints = get_type_hints(Configuration) + for name, value in kwargs.items(): + if name not in hints: + raise TypeError(f"Unexpected keyword argument '{name}'") + + hint = hints[name] + if not isinstance(value, hint): # type: ignore + raise TypeError( + f"Unexpected type for keyword argument '{name}' " + f"(expected '{hint.__name__}', got '{type(value).__name__}')" + ) + + setattr(result, name, value) + + return result diff --git a/tests/compilation/test_configuration.py b/tests/compilation/test_configuration.py index d628057f4..8be948c5b 100644 --- a/tests/compilation/test_configuration.py +++ b/tests/compilation/test_configuration.py @@ -26,3 +26,47 @@ def test_configuration_bad_init(kwargs, expected_error, expected_message): Configuration(**kwargs) assert str(excinfo.value) == expected_message + + +def test_configuration_fork(): + """ + Test `fork` method of `Configuration` class. + """ + + config1 = Configuration(enable_unsafe_features=True, loop_parallelize=False) + config2 = config1.fork(enable_unsafe_features=False, loop_parallelize=True) + + assert config1 is not config2 + + assert config1.enable_unsafe_features is True + assert config1.loop_parallelize is False + + assert config2.enable_unsafe_features is False + assert config2.loop_parallelize is True + + +@pytest.mark.parametrize( + "kwargs,expected_error,expected_message", + [ + pytest.param( + {"foo": False}, + TypeError, + "Unexpected keyword argument 'foo'", + ), + pytest.param( + {"dump_artifacts_on_unexpected_failures": "yes"}, + TypeError, + "Unexpected type for keyword argument 'dump_artifacts_on_unexpected_failures' " + "(expected 'bool', got 'str')", + ), + ], +) +def test_configuration_bad_fork(kwargs, expected_error, expected_message): + """ + Test `fork` method of `Configuration` class with bad parameters. + """ + + with pytest.raises(expected_error) as excinfo: + Configuration().fork(**kwargs) + + assert str(excinfo.value) == expected_message