mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: allow saving servers via MLIR
This commit is contained in:
@@ -114,10 +114,20 @@ class Configuration:
|
||||
raise TypeError(f"Unexpected keyword argument '{name}'")
|
||||
|
||||
hint = hints[name]
|
||||
if not isinstance(value, hint): # type: ignore
|
||||
is_correctly_typed = True
|
||||
|
||||
if name == "insecure_key_cache_location":
|
||||
if not (value is None or isinstance(value, str)):
|
||||
is_correctly_typed = False
|
||||
|
||||
elif not isinstance(value, hint): # type: ignore
|
||||
is_correctly_typed = False
|
||||
|
||||
if not is_correctly_typed:
|
||||
expected = hint.__name__ if hasattr(hint, "__name__") else str(hint)
|
||||
raise TypeError(
|
||||
f"Unexpected type for keyword argument '{name}' "
|
||||
f"(expected '{hint.__name__}', got '{type(value).__name__}')"
|
||||
f"(expected '{expected}', got '{type(value).__name__}')"
|
||||
)
|
||||
|
||||
setattr(result, name, value)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Declaration of `Server` class.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -37,6 +38,9 @@ class Server:
|
||||
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
|
||||
_server_lambda: Union[JITLambda, LibraryLambda]
|
||||
|
||||
_mlir: Optional[str]
|
||||
_configuration: Optional[Configuration]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_specs: ClientSpecs,
|
||||
@@ -51,6 +55,7 @@ class Server:
|
||||
self._support = support
|
||||
self._compilation_result = compilation_result
|
||||
self._server_lambda = server_lambda
|
||||
self._mlir = None
|
||||
|
||||
assert_that(
|
||||
support.load_client_parameters(compilation_result).serialize()
|
||||
@@ -112,27 +117,61 @@ class Server:
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
client_specs = ClientSpecs(input_signs, client_parameters, output_signs)
|
||||
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
result = Server(client_specs, output_dir, support, compilation_result, server_lambda)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
result._mlir = mlir
|
||||
result._configuration = configuration
|
||||
# pylint: enable=protected-access
|
||||
|
||||
return result
|
||||
|
||||
def save(self, path: Union[str, Path], via_mlir: bool = False):
|
||||
"""
|
||||
Save the server into the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to save the server
|
||||
|
||||
via_mlir (bool, default = False)
|
||||
export using the MLIR code of the program,
|
||||
this will make the export cross-platform
|
||||
"""
|
||||
|
||||
path = str(path)
|
||||
if path.endswith(".zip"):
|
||||
path = path[: len(path) - 4]
|
||||
|
||||
if via_mlir:
|
||||
if self._mlir is None or self._configuration is None:
|
||||
raise RuntimeError("Loaded server objects cannot be saved again via MLIR")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
|
||||
with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
|
||||
f.write(self._mlir)
|
||||
|
||||
with open(Path(tmp) / "input_signs.json", "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.client_specs.input_signs))
|
||||
|
||||
with open(Path(tmp) / "output_signs.json", "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.client_specs.output_signs))
|
||||
|
||||
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self._configuration.__dict__))
|
||||
|
||||
shutil.make_archive(path, "zip", tmp)
|
||||
|
||||
return
|
||||
|
||||
if self._output_dir is None:
|
||||
raise RuntimeError("Just-in-Time compilation cannot be saved")
|
||||
|
||||
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
|
||||
f.write(self.client_specs.serialize())
|
||||
|
||||
path = str(path)
|
||||
if path.endswith(".zip"):
|
||||
path = path[: len(path) - 4]
|
||||
|
||||
shutil.make_archive(path, "zip", self._output_dir.name)
|
||||
|
||||
@staticmethod
|
||||
@@ -156,6 +195,25 @@ class Server:
|
||||
|
||||
shutil.unpack_archive(path, str(output_dir_path), "zip")
|
||||
|
||||
if (output_dir_path / "circuit.mlir").exists():
|
||||
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
|
||||
mlir = f.read()
|
||||
|
||||
with open(output_dir_path / "input_signs.json", "r", encoding="utf-8") as f:
|
||||
input_signs = json.load(f)
|
||||
assert_that(isinstance(input_signs, list))
|
||||
assert_that(all(isinstance(sign, bool) for sign in input_signs))
|
||||
|
||||
with open(output_dir_path / "output_signs.json", "r", encoding="utf-8") as f:
|
||||
output_signs = json.load(f)
|
||||
assert_that(isinstance(output_signs, list))
|
||||
assert_that(all(isinstance(sign, bool) for sign in output_signs))
|
||||
|
||||
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
|
||||
configuration = Configuration().fork(**json.load(f))
|
||||
|
||||
return Server.create(mlir, input_signs, output_signs, configuration)
|
||||
|
||||
with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
|
||||
client_specs = ClientSpecs.unserialize(f.read())
|
||||
|
||||
|
||||
@@ -182,6 +182,66 @@ def test_client_server_api(helpers):
|
||||
Client.load(client_path, configuration.insecure_key_cache_location),
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
args = client.encrypt([3, 8, 1])
|
||||
|
||||
serialized_args = client.specs.serialize_public_args(args)
|
||||
serialized_evaluation_keys = client.evaluation_keys.serialize()
|
||||
|
||||
unserialized_args = server.client_specs.unserialize_public_args(serialized_args)
|
||||
unserialized_evaluation_keys = EvaluationKeys.unserialize(serialized_evaluation_keys)
|
||||
|
||||
result = server.run(unserialized_args, unserialized_evaluation_keys)
|
||||
serialized_result = server.client_specs.serialize_public_result(result)
|
||||
|
||||
unserialized_result = client.specs.unserialize_public_result(serialized_result)
|
||||
output = client.decrypt(unserialized_result)
|
||||
|
||||
assert np.array_equal(output, [45, 50, 43])
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
server.save("UNUSED", via_mlir=True)
|
||||
|
||||
assert str(excinfo.value) == "Loaded server objects cannot be saved again via MLIR"
|
||||
|
||||
server.cleanup()
|
||||
|
||||
|
||||
def test_client_server_api_via_mlir(helpers):
|
||||
"""
|
||||
Test client/server API.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)]
|
||||
circuit = function.compile(inputset, configuration.fork(jit=False))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_dir_path = Path(tmp_dir)
|
||||
|
||||
server_path = tmp_dir_path / "server.zip"
|
||||
circuit.server.save(server_path, via_mlir=True)
|
||||
|
||||
client_path = tmp_dir_path / "client.zip"
|
||||
circuit.client.save(client_path)
|
||||
|
||||
circuit.cleanup()
|
||||
|
||||
server = Server.load(server_path)
|
||||
|
||||
serialized_client_specs = server.client_specs.serialize()
|
||||
client_specs = ClientSpecs.unserialize(serialized_client_specs)
|
||||
|
||||
clients = [
|
||||
Client(client_specs, configuration.insecure_key_cache_location),
|
||||
Client.load(client_path, configuration.insecure_key_cache_location),
|
||||
]
|
||||
|
||||
for client in clients:
|
||||
args = client.encrypt([3, 8, 1])
|
||||
|
||||
|
||||
@@ -68,6 +68,12 @@ def test_configuration_fork():
|
||||
"Unexpected type for keyword argument 'dump_artifacts_on_unexpected_failures' "
|
||||
"(expected 'bool', got 'str')",
|
||||
),
|
||||
pytest.param(
|
||||
{"insecure_key_cache_location": 3},
|
||||
TypeError,
|
||||
"Unexpected type for keyword argument 'insecure_key_cache_location' "
|
||||
"(expected 'typing.Union[str, NoneType]', got 'int')",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_configuration_bad_fork(kwargs, expected_error, expected_message):
|
||||
|
||||
Reference in New Issue
Block a user