diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 074794b26..d498cd20b 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -26,7 +26,8 @@ class Configuration: dataflow_parallelize: bool auto_parallelize: bool jit: bool - p_error: float + p_error: Optional[float] + global_p_error: float insecure_key_cache_location: Optional[str] # pylint: enable=too-many-instance-attributes @@ -70,7 +71,8 @@ class Configuration: dataflow_parallelize: bool = False, auto_parallelize: bool = False, jit: bool = False, - p_error: float = 6.3342483999973e-05, + p_error: Optional[float] = None, + global_p_error: float = (1 / 100_000), ): self.verbose = verbose self.show_graph = show_graph @@ -88,6 +90,7 @@ class Configuration: self.auto_parallelize = auto_parallelize self.jit = jit self.p_error = p_error + self.global_p_error = global_p_error self._validate() @@ -122,6 +125,11 @@ class Configuration: is_correctly_typed = False expected = "Optional[str]" + elif name == "p_error": + if not (value is None or isinstance(value, float)): + is_correctly_typed = False + expected = "Optional[float]" + elif not isinstance(value, hint): # type: ignore is_correctly_typed = False diff --git a/concrete/numpy/compilation/server.py b/concrete/numpy/compilation/server.py index f35d47174..94fb3748c 100644 --- a/concrete/numpy/compilation/server.py +++ b/concrete/numpy/compilation/server.py @@ -94,7 +94,10 @@ class Server: options.set_loop_parallelize(configuration.loop_parallelize) options.set_dataflow_parallelize(configuration.dataflow_parallelize) options.set_auto_parallelize(configuration.auto_parallelize) - options.set_p_error(configuration.p_error) + if configuration.p_error is not None: + options.set_p_error(configuration.p_error) + else: + options.set_global_p_error(configuration.global_p_error) options.set_display_optimizer_choice(configuration.verbose or configuration.show_optimizer) if configuration.jit: diff --git a/docs/howto/configure.md b/docs/howto/configure.md index f167cec41..b8071a0af 100644 --- a/docs/howto/configure.md +++ b/docs/howto/configure.md @@ -64,8 +64,11 @@ Additional kwarg to `compile` function have higher precedence. So if you set an * **dump_artifacts_on_unexpected_failures**: bool = True * Whether to export debugging artifacts automatically on compilation failures. -* **p_error**: float = 0.000063342483999973 - * Error probability for table lookups. +* **p_error**: Optional[float] = None + * Error probability for individual table lookups. Overwrites **global_p_error** if set. + +* **global_p_error**: float = (1 / 100_000) + * Global error probability for the whole circuit. * **jit**: bool = False * Whether to use JIT compilation. diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 1f08e8483..1d693c655 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -23,7 +23,7 @@ def test_circuit_str(helpers): return x + y inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)] - circuit = f.compile(inputset, configuration) + circuit = f.compile(inputset, configuration.fork(p_error=6e-5)) assert str(circuit) == ( """ diff --git a/tests/compilation/test_configuration.py b/tests/compilation/test_configuration.py index 6194c8dfc..34632de40 100644 --- a/tests/compilation/test_configuration.py +++ b/tests/compilation/test_configuration.py @@ -74,6 +74,12 @@ def test_configuration_fork(): "Unexpected type for keyword argument 'insecure_key_cache_location' " "(expected 'Optional[str]', got 'int')", ), + pytest.param( + {"p_error": "yes"}, + TypeError, + "Unexpected type for keyword argument 'p_error' " + "(expected 'Optional[float]', got 'str')", + ), ], ) def test_configuration_bad_fork(kwargs, expected_error, expected_message): diff --git a/tests/conftest.py b/tests/conftest.py index 1ad55bfdb..b43a226df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,6 +111,7 @@ class Helpers: auto_parallelize=False, jit=True, insecure_key_cache_location=INSECURE_KEY_CACHE_LOCATION, + global_p_error=(1 / 10_000), ) @staticmethod