From 4d22dec7050e5ed87433a7a20664ea54fe4a5da9 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 16 May 2022 11:11:28 +0100 Subject: [PATCH] fix: make sure path to keyset cache is set when enabling the cache --- concrete/numpy/compilation/circuit.py | 9 +++---- concrete/numpy/compilation/configuration.py | 27 +++++++++------------ tests/compilation/test_circuit.py | 6 ++--- tests/compilation/test_configuration.py | 9 +++++++ tests/conftest.py | 11 ++++++--- 5 files changed, 35 insertions(+), 27 deletions(-) diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index cc37de1c1..230d1cc74 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -48,14 +48,13 @@ class Circuit: self.server = Server.create(mlir, output_signs, self.configuration) - keyset_cache = None + keyset_cache_directory = None if self.configuration.use_insecure_key_cache: assert_that(self.configuration.enable_unsafe_features) - location = Configuration.insecure_key_cache_location() - if location is not None: - keyset_cache = str(location) + assert_that(self.configuration.insecure_keycache_location is not None) + keyset_cache_directory = self.configuration.insecure_keycache_location - self.client = Client(self.server.client_specs, keyset_cache) + self.client = Client(self.server.client_specs, keyset_cache_directory) def __str__(self): return self.graph.format() diff --git a/concrete/numpy/compilation/configuration.py b/concrete/numpy/compilation/configuration.py index 3face7886..59548d63e 100644 --- a/concrete/numpy/compilation/configuration.py +++ b/concrete/numpy/compilation/configuration.py @@ -3,9 +3,8 @@ Declaration of `Configuration` class. """ from copy import deepcopy -from typing import Optional, get_type_hints - -_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None +from pathlib import Path +from typing import Optional, Union, get_type_hints class Configuration: @@ -27,6 +26,7 @@ class Configuration: auto_parallelize: bool jit: bool p_error: float + insecure_keycache_location: Optional[str] # pylint: enable=too-many-instance-attributes @@ -47,6 +47,11 @@ class Configuration: "Virtual compilation is not allowed without enabling unsafe features" ) + if self.use_insecure_key_cache and self.insecure_keycache_location is None: + raise RuntimeError( + "Insecure key cache cannot be enabled without specifying its location" + ) + # pylint: disable=too-many-arguments def __init__( @@ -58,6 +63,7 @@ class Configuration: enable_unsafe_features: bool = False, virtual: bool = False, use_insecure_key_cache: bool = False, + insecure_keycache_location: Optional[Union[Path, str]] = None, loop_parallelize: bool = True, dataflow_parallelize: bool = False, auto_parallelize: bool = False, @@ -71,6 +77,9 @@ class Configuration: self.enable_unsafe_features = enable_unsafe_features self.virtual = virtual self.use_insecure_key_cache = use_insecure_key_cache + self.insecure_keycache_location = ( + str(insecure_keycache_location) if insecure_keycache_location is not None else None + ) self.loop_parallelize = loop_parallelize self.dataflow_parallelize = dataflow_parallelize self.auto_parallelize = auto_parallelize @@ -81,18 +90,6 @@ class Configuration: # pylint: enable=too-many-arguments - @staticmethod - def insecure_key_cache_location() -> Optional[str]: - """ - Get insecure key cache location. - - Returns: - Optional[str]: - insecure key cache location if configured, None otherwise - """ - - return _INSECURE_KEY_CACHE_LOCATION - def fork(self, **kwargs) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index e95df7fe2..a0a8f28cf 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -8,7 +8,7 @@ from pathlib import Path import numpy as np import pytest -from concrete.numpy import Client, ClientSpecs, Configuration, Server +from concrete.numpy import Client, ClientSpecs, Server from concrete.numpy.compilation import compiler @@ -201,8 +201,8 @@ def test_client_server_api(helpers): client_specs = ClientSpecs.unserialize(serialized_client_specs) clients = [ - Client(client_specs, Configuration.insecure_key_cache_location()), - Client.load(client_path, Configuration.insecure_key_cache_location()), + Client(client_specs, configuration.insecure_keycache_location), + Client.load(client_path, configuration.insecure_keycache_location), ] for client in clients: diff --git a/tests/compilation/test_configuration.py b/tests/compilation/test_configuration.py index 8be948c5b..7ffbdbbe8 100644 --- a/tests/compilation/test_configuration.py +++ b/tests/compilation/test_configuration.py @@ -15,6 +15,15 @@ from concrete.numpy.compilation import Configuration RuntimeError, "Insecure key cache cannot be used without enabling unsafe features", ), + pytest.param( + { + "enable_unsafe_features": True, + "use_insecure_key_cache": True, + "insecure_keycache_location": None, + }, + RuntimeError, + "Insecure key cache cannot be enabled without specifying its location", + ), ], ) def test_configuration_bad_init(kwargs, expected_error, expected_message): diff --git a/tests/conftest.py b/tests/conftest.py index b78545c4a..5167d2079 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,8 @@ import numpy as np import pytest import concrete.numpy as cnp -from concrete.numpy.compilation import configuration as configuration_ + +INSECURE_KEY_CACHE_LOCATION = None def pytest_addoption(parser): @@ -39,6 +40,9 @@ def pytest_sessionstart(session): """ Initialize insecure key cache. """ + # pylint: disable=global-statement + global INSECURE_KEY_CACHE_LOCATION + # pylint: enable=global-statement key_cache_location = session.config.getoption("--key-cache", default=None) if key_cache_location is not None: @@ -53,9 +57,7 @@ def pytest_sessionstart(session): key_cache_location.mkdir(parents=True, exist_ok=True) print(f"INSECURE_KEY_CACHE_LOCATION={str(key_cache_location)}") - # pylint: disable=protected-access - configuration_._INSECURE_KEY_CACHE_LOCATION = str(key_cache_location) - # pylint: enable=protected-access + INSECURE_KEY_CACHE_LOCATION = str(key_cache_location) def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument @@ -108,6 +110,7 @@ class Helpers: dataflow_parallelize=False, auto_parallelize=False, jit=True, + insecure_keycache_location=INSECURE_KEY_CACHE_LOCATION, ) @staticmethod