fix: make sure path to keyset cache is set when enabling the cache

This commit is contained in:
youben11
2022-05-16 11:11:28 +01:00
committed by Ayoub Benaissa
parent c9bb05df82
commit 4d22dec705
5 changed files with 35 additions and 27 deletions

View File

@@ -48,14 +48,13 @@ class Circuit:
self.server = Server.create(mlir, output_signs, self.configuration)
keyset_cache = None
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
assert_that(self.configuration.enable_unsafe_features)
location = Configuration.insecure_key_cache_location()
if location is not None:
keyset_cache = str(location)
assert_that(self.configuration.insecure_keycache_location is not None)
keyset_cache_directory = self.configuration.insecure_keycache_location
self.client = Client(self.server.client_specs, keyset_cache)
self.client = Client(self.server.client_specs, keyset_cache_directory)
def __str__(self):
return self.graph.format()

View File

@@ -3,9 +3,8 @@ Declaration of `Configuration` class.
"""
from copy import deepcopy
from typing import Optional, get_type_hints
_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None
from pathlib import Path
from typing import Optional, Union, get_type_hints
class Configuration:
@@ -27,6 +26,7 @@ class Configuration:
auto_parallelize: bool
jit: bool
p_error: float
insecure_keycache_location: Optional[str]
# pylint: enable=too-many-instance-attributes
@@ -47,6 +47,11 @@ class Configuration:
"Virtual compilation is not allowed without enabling unsafe features"
)
if self.use_insecure_key_cache and self.insecure_keycache_location is None:
raise RuntimeError(
"Insecure key cache cannot be enabled without specifying its location"
)
# pylint: disable=too-many-arguments
def __init__(
@@ -58,6 +63,7 @@ class Configuration:
enable_unsafe_features: bool = False,
virtual: bool = False,
use_insecure_key_cache: bool = False,
insecure_keycache_location: Optional[Union[Path, str]] = None,
loop_parallelize: bool = True,
dataflow_parallelize: bool = False,
auto_parallelize: bool = False,
@@ -71,6 +77,9 @@ class Configuration:
self.enable_unsafe_features = enable_unsafe_features
self.virtual = virtual
self.use_insecure_key_cache = use_insecure_key_cache
self.insecure_keycache_location = (
str(insecure_keycache_location) if insecure_keycache_location is not None else None
)
self.loop_parallelize = loop_parallelize
self.dataflow_parallelize = dataflow_parallelize
self.auto_parallelize = auto_parallelize
@@ -81,18 +90,6 @@ class Configuration:
# pylint: enable=too-many-arguments
@staticmethod
def insecure_key_cache_location() -> Optional[str]:
"""
Get insecure key cache location.
Returns:
Optional[str]:
insecure key cache location if configured, None otherwise
"""
return _INSECURE_KEY_CACHE_LOCATION
def fork(self, **kwargs) -> "Configuration":
"""
Get a new configuration from another one specified changes.

View File

@@ -8,7 +8,7 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.numpy import Client, ClientSpecs, Configuration, Server
from concrete.numpy import Client, ClientSpecs, Server
from concrete.numpy.compilation import compiler
@@ -201,8 +201,8 @@ def test_client_server_api(helpers):
client_specs = ClientSpecs.unserialize(serialized_client_specs)
clients = [
Client(client_specs, Configuration.insecure_key_cache_location()),
Client.load(client_path, Configuration.insecure_key_cache_location()),
Client(client_specs, configuration.insecure_keycache_location),
Client.load(client_path, configuration.insecure_keycache_location),
]
for client in clients:

View File

@@ -15,6 +15,15 @@ from concrete.numpy.compilation import Configuration
RuntimeError,
"Insecure key cache cannot be used without enabling unsafe features",
),
pytest.param(
{
"enable_unsafe_features": True,
"use_insecure_key_cache": True,
"insecure_keycache_location": None,
},
RuntimeError,
"Insecure key cache cannot be enabled without specifying its location",
),
],
)
def test_configuration_bad_init(kwargs, expected_error, expected_message):

View File

@@ -11,7 +11,8 @@ import numpy as np
import pytest
import concrete.numpy as cnp
from concrete.numpy.compilation import configuration as configuration_
INSECURE_KEY_CACHE_LOCATION = None
def pytest_addoption(parser):
@@ -39,6 +40,9 @@ def pytest_sessionstart(session):
"""
Initialize insecure key cache.
"""
# pylint: disable=global-statement
global INSECURE_KEY_CACHE_LOCATION
# pylint: enable=global-statement
key_cache_location = session.config.getoption("--key-cache", default=None)
if key_cache_location is not None:
@@ -53,9 +57,7 @@ def pytest_sessionstart(session):
key_cache_location.mkdir(parents=True, exist_ok=True)
print(f"INSECURE_KEY_CACHE_LOCATION={str(key_cache_location)}")
# pylint: disable=protected-access
configuration_._INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
# pylint: enable=protected-access
INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument
@@ -108,6 +110,7 @@ class Helpers:
dataflow_parallelize=False,
auto_parallelize=False,
jit=True,
insecure_keycache_location=INSECURE_KEY_CACHE_LOCATION,
)
@staticmethod