From 1472c8f02066f4760370b6ebacffa9b725ea8327 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 18 Nov 2022 14:32:05 +0100 Subject: [PATCH] feat: make both p error and global p error optional --- concrete/numpy/compilation/configuration.py | 13 +++++++++++-- concrete/numpy/compilation/server.py | 17 ++++++++++++++--- docs/howto/configure.md | 6 +++--- tests/compilation/test_configuration.py | 6 ++++++ 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index f8c64de51..8c89a431e 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -27,7 +27,7 @@ class Configuration: auto_parallelize: bool jit: bool p_error: Optional[float] - global_p_error: float + global_p_error: Optional[float] insecure_key_cache_location: Optional[str] auto_adjust_rounders: bool @@ -73,7 +73,7 @@ class Configuration: auto_parallelize: bool = False, jit: bool = False, p_error: Optional[float] = None, - global_p_error: float = (1 / 100_000), + global_p_error: Optional[float] = (1 / 100_000), auto_adjust_rounders: bool = False, ): self.verbose = verbose @@ -112,6 +112,8 @@ class Configuration: configuration that is forked from self and updated using kwargs """ + # pylint: disable=too-many-branches + result = deepcopy(self) hints = get_type_hints(Configuration) @@ -133,6 +135,11 @@ class Configuration: is_correctly_typed = False expected = "Optional[float]" + elif name == "global_p_error": + if not (value is None or isinstance(value, float)): + is_correctly_typed = False + expected = "Optional[float]" + elif name in ["show_graph", "show_mlir", "show_optimizer"]: if not (value is None or isinstance(value, bool)): is_correctly_typed = False @@ -156,3 +163,5 @@ class Configuration: # pylint: enable=protected-access return result + + # pylint: enable=too-many-branches diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index 7aeb81e1c..816b44995 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -94,10 +94,21 @@ class Server: options.set_loop_parallelize(configuration.loop_parallelize) options.set_dataflow_parallelize(configuration.dataflow_parallelize) options.set_auto_parallelize(configuration.auto_parallelize) - if configuration.p_error is not None: - options.set_p_error(configuration.p_error) - else: + + global_p_error_is_set = configuration.global_p_error is not None + p_error_is_set = configuration.p_error is not None + + if global_p_error_is_set and p_error_is_set: # pragma: no cover options.set_global_p_error(configuration.global_p_error) + options.set_p_error(configuration.p_error) + + elif global_p_error_is_set: # pragma: no cover + options.set_global_p_error(configuration.global_p_error) + options.set_p_error(1.0) + + elif p_error_is_set: # pragma: no cover + options.set_global_p_error(1.0) + options.set_p_error(configuration.p_error) show_optimizer = ( configuration.show_optimizer diff --git a/docs/howto/configure.md b/docs/howto/configure.md index 2124b6752..3e8ddc69d 100644 --- a/docs/howto/configure.md +++ b/docs/howto/configure.md @@ -74,10 +74,10 @@ Additional kwarg to `compile` function have higher precedence. So if you set an * Whether to adjust rounders automatically. * **p_error**: Optional[float] = None - * Error probability for individual table lookups. Overwrites **global_p_error** if set. + * Error probability for individual table lookups. If set, all table lookups will have the probability of non-exact result smaller than the set value. -* **global_p_error**: float = (1 / 100_000) - * Global error probability for the whole circuit. +* **global_p_error**: Optional[float] = (1 / 100_000) + * Global error probability for the whole circuit. If set, the whole circuit will have the probability of non-exact result smaller than the set value. * **jit**: bool = False * Whether to use JIT compilation. diff --git a/tests/compilation/test_configuration.py b/tests/compilation/test_configuration.py index 464f83f53..085c70773 100644 --- a/tests/compilation/test_configuration.py +++ b/tests/compilation/test_configuration.py @@ -80,6 +80,12 @@ def test_configuration_fork(): "Unexpected type for keyword argument 'p_error' " "(expected 'Optional[float]', got 'str')", ), + pytest.param( + {"global_p_error": "mamma mia"}, + TypeError, + "Unexpected type for keyword argument 'global_p_error' " + "(expected 'Optional[float]', got 'str')", + ), pytest.param( {"show_optimizer": "please"}, TypeError,