mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: make sure path to keyset cache is set when enabling the cache
This commit is contained in:
@@ -48,14 +48,13 @@ class Circuit:
|
||||
|
||||
self.server = Server.create(mlir, output_signs, self.configuration)
|
||||
|
||||
keyset_cache = None
|
||||
keyset_cache_directory = None
|
||||
if self.configuration.use_insecure_key_cache:
|
||||
assert_that(self.configuration.enable_unsafe_features)
|
||||
location = Configuration.insecure_key_cache_location()
|
||||
if location is not None:
|
||||
keyset_cache = str(location)
|
||||
assert_that(self.configuration.insecure_keycache_location is not None)
|
||||
keyset_cache_directory = self.configuration.insecure_keycache_location
|
||||
|
||||
self.client = Client(self.server.client_specs, keyset_cache)
|
||||
self.client = Client(self.server.client_specs, keyset_cache_directory)
|
||||
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
|
||||
@@ -3,9 +3,8 @@ Declaration of `Configuration` class.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional, get_type_hints
|
||||
|
||||
_INSECURE_KEY_CACHE_LOCATION: Optional[str] = None
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, get_type_hints
|
||||
|
||||
|
||||
class Configuration:
|
||||
@@ -27,6 +26,7 @@ class Configuration:
|
||||
auto_parallelize: bool
|
||||
jit: bool
|
||||
p_error: float
|
||||
insecure_keycache_location: Optional[str]
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
|
||||
@@ -47,6 +47,11 @@ class Configuration:
|
||||
"Virtual compilation is not allowed without enabling unsafe features"
|
||||
)
|
||||
|
||||
if self.use_insecure_key_cache and self.insecure_keycache_location is None:
|
||||
raise RuntimeError(
|
||||
"Insecure key cache cannot be enabled without specifying its location"
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
|
||||
def __init__(
|
||||
@@ -58,6 +63,7 @@ class Configuration:
|
||||
enable_unsafe_features: bool = False,
|
||||
virtual: bool = False,
|
||||
use_insecure_key_cache: bool = False,
|
||||
insecure_keycache_location: Optional[Union[Path, str]] = None,
|
||||
loop_parallelize: bool = True,
|
||||
dataflow_parallelize: bool = False,
|
||||
auto_parallelize: bool = False,
|
||||
@@ -71,6 +77,9 @@ class Configuration:
|
||||
self.enable_unsafe_features = enable_unsafe_features
|
||||
self.virtual = virtual
|
||||
self.use_insecure_key_cache = use_insecure_key_cache
|
||||
self.insecure_keycache_location = (
|
||||
str(insecure_keycache_location) if insecure_keycache_location is not None else None
|
||||
)
|
||||
self.loop_parallelize = loop_parallelize
|
||||
self.dataflow_parallelize = dataflow_parallelize
|
||||
self.auto_parallelize = auto_parallelize
|
||||
@@ -81,18 +90,6 @@ class Configuration:
|
||||
|
||||
# pylint: enable=too-many-arguments
|
||||
|
||||
@staticmethod
|
||||
def insecure_key_cache_location() -> Optional[str]:
|
||||
"""
|
||||
Get insecure key cache location.
|
||||
|
||||
Returns:
|
||||
Optional[str]:
|
||||
insecure key cache location if configured, None otherwise
|
||||
"""
|
||||
|
||||
return _INSECURE_KEY_CACHE_LOCATION
|
||||
|
||||
def fork(self, **kwargs) -> "Configuration":
|
||||
"""
|
||||
Get a new configuration from another one specified changes.
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.numpy import Client, ClientSpecs, Configuration, Server
|
||||
from concrete.numpy import Client, ClientSpecs, Server
|
||||
from concrete.numpy.compilation import compiler
|
||||
|
||||
|
||||
@@ -201,8 +201,8 @@ def test_client_server_api(helpers):
|
||||
client_specs = ClientSpecs.unserialize(serialized_client_specs)
|
||||
|
||||
clients = [
|
||||
Client(client_specs, Configuration.insecure_key_cache_location()),
|
||||
Client.load(client_path, Configuration.insecure_key_cache_location()),
|
||||
Client(client_specs, configuration.insecure_keycache_location),
|
||||
Client.load(client_path, configuration.insecure_keycache_location),
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
|
||||
@@ -15,6 +15,15 @@ from concrete.numpy.compilation import Configuration
|
||||
RuntimeError,
|
||||
"Insecure key cache cannot be used without enabling unsafe features",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"enable_unsafe_features": True,
|
||||
"use_insecure_key_cache": True,
|
||||
"insecure_keycache_location": None,
|
||||
},
|
||||
RuntimeError,
|
||||
"Insecure key cache cannot be enabled without specifying its location",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_configuration_bad_init(kwargs, expected_error, expected_message):
|
||||
|
||||
@@ -11,7 +11,8 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import concrete.numpy as cnp
|
||||
from concrete.numpy.compilation import configuration as configuration_
|
||||
|
||||
INSECURE_KEY_CACHE_LOCATION = None
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
@@ -39,6 +40,9 @@ def pytest_sessionstart(session):
|
||||
"""
|
||||
Initialize insecure key cache.
|
||||
"""
|
||||
# pylint: disable=global-statement
|
||||
global INSECURE_KEY_CACHE_LOCATION
|
||||
# pylint: enable=global-statement
|
||||
|
||||
key_cache_location = session.config.getoption("--key-cache", default=None)
|
||||
if key_cache_location is not None:
|
||||
@@ -53,9 +57,7 @@ def pytest_sessionstart(session):
|
||||
key_cache_location.mkdir(parents=True, exist_ok=True)
|
||||
print(f"INSECURE_KEY_CACHE_LOCATION={str(key_cache_location)}")
|
||||
|
||||
# pylint: disable=protected-access
|
||||
configuration_._INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
|
||||
# pylint: enable=protected-access
|
||||
INSECURE_KEY_CACHE_LOCATION = str(key_cache_location)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument
|
||||
@@ -108,6 +110,7 @@ class Helpers:
|
||||
dataflow_parallelize=False,
|
||||
auto_parallelize=False,
|
||||
jit=True,
|
||||
insecure_keycache_location=INSECURE_KEY_CACHE_LOCATION,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user