feat(frontend/python): explicit key management

This commit is contained in:
Umut
2023-04-11 10:35:13 +02:00
committed by Quentin Bourgerie
parent 22a2407b60
commit 673b02473f
17 changed files with 626 additions and 138 deletions

View File

@@ -16,6 +16,7 @@ from .compilation import (
Configuration,
DebugArtifacts,
EncryptionStatus,
Keys,
Server,
)
from .compilation.decorators import circuit, compiler

View File

@@ -7,5 +7,6 @@ from .circuit import Circuit
from .client import Client
from .compiler import Compiler, EncryptionStatus
from .configuration import DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, Configuration
from .keys import Keys
from .server import Server
from .specs import ClientSpecs

View File

@@ -4,18 +4,18 @@ Declaration of `Circuit` class.
# pylint: disable=import-error,no-member,no-name-in-module
from typing import Any, Optional, Tuple, Union, cast
from typing import Any, Optional, Tuple, Union
import numpy as np
# mypy: disable-error-code=attr-defined
from concrete.compiler import PublicArguments, PublicResult
from ..dtypes import Integer
from ..internal.utils import assert_that
from ..representation import Graph
from .client import Client
from .configuration import Configuration
from .keys import Keys
from .server import Server
# pylint: enable=import-error,no-member,no-name-in-module
@@ -43,21 +43,7 @@ class Circuit:
self._initialize_client_and_server()
def _initialize_client_and_server(self):
input_signs = []
for i in range(len(self.graph.input_nodes)): # pylint: disable=consider-using-enumerate
input_value = self.graph.input_nodes[i].output
assert_that(isinstance(input_value.dtype, Integer))
input_dtype = cast(Integer, input_value.dtype)
input_signs.append(input_dtype.is_signed)
output_signs = []
for i in range(len(self.graph.output_nodes)): # pylint: disable=consider-using-enumerate
output_value = self.graph.output_nodes[i].output
assert_that(isinstance(output_value.dtype, Integer))
output_dtype = cast(Integer, output_value.dtype)
output_signs.append(output_dtype.is_signed)
self.server = Server.create(self.mlir, input_signs, output_signs, self.configuration)
self.server = Server.create(self.mlir, self.configuration)
keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
@@ -85,16 +71,33 @@ class Circuit:
return self.graph(*args, p_error=self.p_error)
def keygen(self, force: bool = False):
@property
def keys(self) -> Keys:
"""
Get the keys of the circuit.
"""
return self.client.keys
@keys.setter
def keys(self, new_keys: Keys):
"""
Set the keys of the circuit.
"""
self.client.keys = new_keys
def keygen(self, force: bool = False, seed: Optional[int] = None):
"""
Generate keys required for homomorphic evaluation.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated
seed (Optional[int], default = None):
seed for randomness
"""
self.client.keygen(force)
self.client.keygen(force, seed)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""

View File

@@ -11,20 +11,12 @@ from pathlib import Path
from typing import Dict, Optional, Tuple, Union
import numpy as np
# mypy: disable-error-code=attr-defined
from concrete.compiler import (
ClientSupport,
EvaluationKeys,
KeySet,
KeySetCache,
PublicArguments,
PublicResult,
)
from concrete.compiler import ClientSupport, EvaluationKeys, PublicArguments, PublicResult
from ..dtypes.integer import SignedInteger, UnsignedInteger
from ..internal.utils import assert_that
from ..values.value import Value
from .keys import Keys
from .specs import ClientSpecs
# pylint: enable=import-error,no-member,no-name-in-module
@@ -36,9 +28,7 @@ class Client:
"""
specs: ClientSpecs
_keyset: Optional[KeySet]
_keyset_cache: Optional[KeySetCache]
_keys: Keys
def __init__(
self,
@@ -46,12 +36,7 @@ class Client:
keyset_cache_directory: Optional[Union[str, Path]] = None,
):
self.specs = client_specs
self._keyset = None
self._keyset_cache = None
if keyset_cache_directory is not None:
self._keyset_cache = KeySetCache.new(str(keyset_cache_directory))
self._keys = Keys(client_specs, keyset_cache_directory)
def save(self, path: Union[str, Path]):
"""
@@ -63,7 +48,7 @@ class Client:
"""
with tempfile.TemporaryDirectory() as tmp_dir:
with open(Path(tmp_dir) / "client.specs.json", "w", encoding="utf-8") as f:
with open(Path(tmp_dir) / "client.specs.json", "wb") as f:
f.write(self.specs.serialize())
path = str(path)
@@ -94,22 +79,42 @@ class Client:
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.unpack_archive(path, tmp_dir, "zip")
with open(Path(tmp_dir) / "client.specs.json", "r", encoding="utf-8") as f:
with open(Path(tmp_dir) / "client.specs.json", "rb") as f:
client_specs = ClientSpecs.deserialize(f.read())
return Client(client_specs, keyset_cache_directory)
def keygen(self, force: bool = False):
@property
def keys(self) -> Keys:
"""
Get the keys for the client.
"""
return self._keys
@keys.setter
def keys(self, new_keys: Keys):
"""
Set the keys for the client.
"""
if new_keys.client_specs != self.specs:
message = "Unable to set keys as they are generated for a different circuit"
raise ValueError(message)
self._keys = new_keys
def keygen(self, force: bool = False, seed: Optional[int] = None):
"""
Generate keys required for homomorphic evaluation.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated
seed (Optional[int], default = None):
seed for randomness
"""
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self.specs.client_parameters, self._keyset_cache)
self.keys.generate(force=force, seed=seed)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
@@ -143,12 +148,11 @@ class Client:
)
width = spec["shape"]["width"]
is_signed = spec["shape"]["sign"]
shape = tuple(spec["shape"]["dimensions"])
is_encrypted = spec["encryption"] is not None
expected_dtype = (
SignedInteger(width) if self.specs.input_signs[index] else UnsignedInteger(width)
)
expected_dtype = SignedInteger(width) if is_signed else UnsignedInteger(width)
expected_value = Value(expected_dtype, shape, is_encrypted)
if is_valid:
expected_min = expected_dtype.min()
@@ -175,9 +179,11 @@ class Client:
raise ValueError(message)
self.keygen(force=False)
keyset = self.keys._keyset # pylint: disable=protected-access
return ClientSupport.encrypt_arguments(
self.specs.client_parameters,
self._keyset,
keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
@@ -186,19 +192,20 @@ class Client:
result: PublicResult,
) -> Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]:
"""
Decrypt result of homomorphic evaluaton.
Decrypt result of homomorphic evaluation.
Args:
result (PublicResult):
encrypted result of homomorphic evaluaton
encrypted result of homomorphic evaluation
Returns:
Union[int, numpy.ndarray]:
clear result of homomorphic evaluaton
clear result of homomorphic evaluation
"""
self.keygen(force=False)
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, self._keyset, result)
keyset = self.keys._keyset # pylint: disable=protected-access
outputs = ClientSupport.decrypt_result(self.specs.client_parameters, keyset, result)
return outputs
@property
@@ -212,6 +219,4 @@ class Client:
"""
self.keygen(force=False)
assert self._keyset is not None
return self._keyset.get_evaluation_keys()
return self.keys.evaluation

View File

@@ -0,0 +1,196 @@
"""
Declaration of `Keys` class.
"""
# pylint: disable=import-error,no-name-in-module
import pathlib
from pathlib import Path
from typing import Optional, Union
from concrete.compiler import ClientSupport, EvaluationKeys, KeySet, KeySetCache
from .specs import ClientSpecs
# pylint: enable=import-error,no-name-in-module
class Keys:
"""
Keys class, to manage generate/reuse keys.
Includes encryption keys as well as evaluation keys.
Be careful when serializing/saving keys!
"""
client_specs: ClientSpecs
cache_directory: Optional[Union[str, Path]]
_keyset_cache: Optional[KeySetCache]
_keyset: Optional[KeySet]
def __init__(
self,
client_specs: ClientSpecs,
cache_directory: Optional[Union[str, Path]] = None,
):
self.client_specs = client_specs
self.cache_directory = cache_directory
self._keyset_cache = None
self._keyset = None
if cache_directory is not None:
self._keyset_cache = KeySetCache.new(str(cache_directory))
def generate(self, force: bool = False, seed: Optional[int] = None):
"""
Generate new keys.
Args:
force (bool, default = False):
whether to generate new keys even if keys are already generated/loaded
seed (Optional[int], default = None):
seed for randomness
"""
# seed of 0 will result in a crypto secure randomly generated 128-bit seed
seed_msb = 0
seed_lsb = 0
if seed is not None:
seed_lsb = seed & ((2**64) - 1)
seed_msb = (seed >> 64) & ((2**64) - 1)
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(
self.client_specs.client_parameters,
self._keyset_cache,
seed_msb,
seed_lsb,
)
def save(self, location: Union[str, Path]):
"""
Save keys to a location.
Saved keys are not encrypted, so be careful how you store/transfer them!
Args:
location (Union[str, Path]):
location to save to
"""
if not isinstance(location, Path):
location = pathlib.Path(location)
if location.exists():
message = f"Unable to save keys to {location} because it already exists"
raise ValueError(message)
location.write_bytes(self.serialize())
def load(self, location: Union[str, Path]):
"""
Load keys from a location.
Args:
location (Union[str, Path]):
location to load from
"""
if not isinstance(location, Path):
location = pathlib.Path(location)
if not location.exists():
message = f"Unable to load keys from {location} because it doesn't exist"
raise ValueError(message)
keys = Keys.deserialize(bytes(location.read_bytes()))
self.client_specs = keys.client_specs
self.cache_directory = None
# pylint: disable=protected-access
self._keyset_cache = None
self._keyset = keys._keyset
# pylint: enable=protected-access
def load_if_exists_generate_and_save_otherwise(
self,
location: Union[str, Path],
seed: Optional[int] = None,
):
"""
Load keys from a location if they exist, else generate new keys and save to that location.
Args:
location (Union[str, Path]):
location to load from or save to
seed (Optional[int], default = None):
seed for randomness in case keys need to be generated
"""
if not isinstance(location, Path):
location = pathlib.Path(location)
if location.exists():
self.load(location)
else:
self.generate(seed=seed)
self.save(location)
def serialize(self) -> bytes:
"""
Serialize keys into bytes.
Serialized keys are not encrypted, so be careful how you store/transfer them!
Returns:
bytes:
serialized keys
"""
if self._keyset is None:
message = "Keys cannot be serialized before they are generated"
raise RuntimeError(message)
serialized_keyset = self._keyset.serialize()
return serialized_keyset
@staticmethod
def deserialize(serialized_keys: bytes) -> "Keys":
"""
Deserialize keys from bytes.
Args:
serialized_keys (bytes):
previously serialized keys
Returns:
Keys:
deserialized keys
"""
keyset = KeySet.deserialize(serialized_keys)
client_specs = ClientSpecs(keyset.client_parameters())
# pylint: disable=protected-access
result = Keys(client_specs)
result._keyset = keyset
# pylint: enable=protected-access
return result
@property
def evaluation(self) -> EvaluationKeys:
"""
Get only evaluation keys.
"""
self.generate(force=False)
assert self._keyset is not None
return self._keyset.get_evaluation_keys()

View File

@@ -8,7 +8,7 @@ import json
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Union
from typing import Optional, Union
# mypy: disable-error-code=attr-defined
import concrete.compiler
@@ -72,12 +72,7 @@ class Server:
)
@staticmethod
def create(
mlir: str,
input_signs: List[bool],
output_signs: List[bool],
configuration: Configuration,
) -> "Server":
def create(mlir: str, configuration: Configuration) -> "Server":
"""
Create a server using MLIR and output sign information.
@@ -85,12 +80,6 @@ class Server:
mlir (str):
mlir to compile
input_signs (List[bool]):
sign status of the inputs
output_signs (List[bool]):
sign status of the outputs
configuration (Optional[Configuration], default = None):
configuration to use
"""
@@ -159,7 +148,7 @@ class Server:
server_lambda = support.load_server_lambda(compilation_result)
client_parameters = support.load_client_parameters(compilation_result)
client_specs = ClientSpecs(input_signs, client_parameters, output_signs)
client_specs = ClientSpecs(client_parameters)
result = Server(client_specs, output_dir, support, compilation_result, server_lambda)
@@ -196,12 +185,6 @@ class Server:
with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
f.write(self._mlir)
with open(Path(tmp) / "input_signs.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self.client_specs.input_signs))
with open(Path(tmp) / "output_signs.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self.client_specs.output_signs))
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._configuration.__dict__))
@@ -213,7 +196,7 @@ class Server:
message = "Just-in-Time compilation cannot be saved"
raise RuntimeError(message)
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
with open(Path(self._output_dir.name) / "client.specs.json", "wb") as f:
f.write(self.client_specs.serialize())
shutil.make_archive(path, "zip", self._output_dir.name)
@@ -243,22 +226,12 @@ class Server:
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
mlir = f.read()
with open(output_dir_path / "input_signs.json", "r", encoding="utf-8") as f:
input_signs = json.load(f)
assert_that(isinstance(input_signs, list))
assert_that(all(isinstance(sign, bool) for sign in input_signs))
with open(output_dir_path / "output_signs.json", "r", encoding="utf-8") as f:
output_signs = json.load(f)
assert_that(isinstance(output_signs, list))
assert_that(all(isinstance(sign, bool) for sign in output_signs))
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**json.load(f))
return Server.create(mlir, input_signs, output_signs, configuration)
return Server.create(mlir, configuration)
with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
with open(output_dir_path / "client.specs.json", "rb") as f:
client_specs = ClientSpecs.deserialize(f.read())
support = LibrarySupport.new(

View File

@@ -4,8 +4,7 @@ Declaration of `ClientSpecs` class.
# pylint: disable=import-error,no-member,no-name-in-module
import json
from typing import List
from typing import Any
# mypy: disable-error-code=attr-defined
from concrete.compiler import ClientParameters, PublicArguments, PublicResult
@@ -18,45 +17,35 @@ class ClientSpecs:
ClientSpecs class, to create Client objects.
"""
input_signs: List[bool]
client_parameters: ClientParameters
output_signs: List[bool]
def __init__(
self,
input_signs: List[bool],
client_parameters: ClientParameters,
output_signs: List[bool],
):
self.input_signs = input_signs
def __init__(self, client_parameters: ClientParameters):
self.client_parameters = client_parameters
self.output_signs = output_signs
def serialize(self) -> str:
def __eq__(self, other: Any):
if self.client_parameters.serialize() != other.client_parameters.serialize():
return False
return True
def serialize(self) -> bytes:
"""
Serialize client specs into a string representation.
Returns:
str:
string representation of the client specs
bytes:
serialized client specs
"""
client_parameters_json = json.loads(self.client_parameters.serialize())
return json.dumps(
{
"input_signs": self.input_signs,
"client_parameters": client_parameters_json,
"output_signs": self.output_signs,
}
)
return self.client_parameters.serialize()
@staticmethod
def deserialize(serialized_client_specs: str) -> "ClientSpecs":
def deserialize(serialized_client_specs: bytes) -> "ClientSpecs":
"""
Create client specs from its string representation.
Args:
serialized_client_specs (str):
serialized_client_specs (bytes):
client specs to deserialize
Returns:
@@ -64,12 +53,8 @@ class ClientSpecs:
deserialized client specs
"""
raw_specs = json.loads(serialized_client_specs)
client_parameters_bytes = json.dumps(raw_specs["client_parameters"]).encode("utf-8")
client_parameters = ClientParameters.deserialize(client_parameters_bytes)
return ClientSpecs(raw_specs["input_signs"], client_parameters, raw_specs["output_signs"])
client_parameters = ClientParameters.deserialize(serialized_client_specs)
return ClientSpecs(client_parameters)
def serialize_public_args(self, args: PublicArguments) -> bytes:
"""
@@ -88,7 +73,7 @@ class ClientSpecs:
def deserialize_public_args(self, serialized_args: bytes) -> PublicArguments:
"""
Unserialize public arguments from bytes.
Deserialize public arguments from bytes.
Args:
serialized_args (bytes):
@@ -118,7 +103,7 @@ class ClientSpecs:
def deserialize_public_result(self, serialized_result: bytes) -> PublicResult:
"""
Unserialize public result from bytes.
Deserialize public result from bytes.
Args:
serialized_result (bytes):

View File

@@ -0,0 +1,192 @@
"""
Tests of `Keys` class.
"""
import tempfile
from pathlib import Path
import pytest
from concrete import fhe
def test_keys_save_load(helpers):
"""
Test saving and loading keys.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x**2
inputset = range(10)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
keys_path = tmp_dir_path / "keys"
circuit1 = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
circuit1.keygen()
sample = circuit1.encrypt(5)
evaluation = circuit1.run(sample)
circuit1.keys.save(str(keys_path))
circuit2 = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
circuit2.keys.load(str(keys_path))
assert circuit2.decrypt(evaluation) == 25
def test_keys_bad_save_load(helpers):
"""
Test saving/loading keys where location is (not) empty.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x**2
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
keys_path = tmp_dir_path / "keys"
inputset = range(10)
circuit = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
with pytest.raises(ValueError) as excinfo:
circuit.keys.load(keys_path)
expected_message = f"Unable to load keys from {keys_path} because it doesn't exist"
helpers.check_str(expected_message, str(excinfo.value))
with open(keys_path, "w", encoding="utf-8") as f:
f.write("foo")
circuit.keys.generate()
with pytest.raises(ValueError) as excinfo:
circuit.keys.save(keys_path)
expected_message = f"Unable to save keys to {keys_path} because it already exists"
helpers.check_str(expected_message, str(excinfo.value))
def test_keys_load_if_exists_generate_and_save_otherwise(helpers):
"""
Test saving and loading keys using `load_if_exists_generate_and_save_otherwise`.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x**2
inputset = range(10)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
keys_path = tmp_dir_path / "keys"
circuit1 = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
circuit1.keys.load_if_exists_generate_and_save_otherwise(str(keys_path))
sample = circuit1.encrypt(5)
evaluation = circuit1.run(sample)
circuit2 = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
circuit2.keys.load_if_exists_generate_and_save_otherwise(str(keys_path))
assert circuit2.decrypt(evaluation) == 25
def test_keys_serialize_deserialize(helpers):
"""
Test serializing and deserializing keys.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x**2
inputset = range(10)
circuit = f.compile(inputset, helpers.configuration())
server = circuit.server
client1 = fhe.Client(server.client_specs)
client1.keys.generate()
sample = client1.encrypt(5)
evaluation = server.run(sample, client1.evaluation_keys)
client2 = fhe.Client(server.client_specs)
client2.keys = fhe.Keys.deserialize(client1.keys.serialize())
assert client2.decrypt(evaluation) == 25
def test_keys_serialize_before_generation(helpers):
"""
Test serialization of keys before their generation.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x + 42
inputset = range(10)
circuit = f.compile(inputset, configuration=helpers.configuration())
with pytest.raises(RuntimeError) as excinfo:
circuit.keys.serialize()
expected_message = "Keys cannot be serialized before they are generated"
helpers.check_str(expected_message, str(excinfo.value))
def test_keys_generate_manual_seed(helpers):
"""
Test key generation with custom seed.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x**2
inputset = range(10)
circuit = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
circuit.keygen(seed=42)
sample = circuit.encrypt(5)
evaluation = circuit.run(sample)
same_circuit = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))
same_circuit.keygen(seed=42)
assert same_circuit.decrypt(evaluation) == 25
def test_assign_keys_with_different_parameters(helpers):
"""
Test assigning incompatible keys to a circuit.
"""
@fhe.compiler({"x": "encrypted"})
def f(x):
return x + 42
@fhe.compiler({"x": "encrypted"})
def g(x):
return x**2
f_circuit = f.compile(inputset=range(99), configuration=helpers.configuration())
g_circuit = g.compile(inputset=range(10), configuration=helpers.configuration())
f_circuit.keygen()
g_circuit.keygen()
with pytest.raises(ValueError) as excinfo:
f_circuit.keys = g_circuit.keys
expected_message = "Unable to set keys as they are generated for a different circuit"
helpers.check_str(expected_message, str(excinfo.value))