feat: allow saving servers via MLIR

This commit is contained in:
Umut
2022-08-23 16:20:38 +02:00
parent 7415dd07e1
commit 05282285a3
4 changed files with 142 additions and 8 deletions

View File

@@ -114,10 +114,20 @@ class Configuration:
raise TypeError(f"Unexpected keyword argument '{name}'")
hint = hints[name]
if not isinstance(value, hint): # type: ignore
is_correctly_typed = True
if name == "insecure_key_cache_location":
if not (value is None or isinstance(value, str)):
is_correctly_typed = False
elif not isinstance(value, hint): # type: ignore
is_correctly_typed = False
if not is_correctly_typed:
expected = hint.__name__ if hasattr(hint, "__name__") else str(hint)
raise TypeError(
f"Unexpected type for keyword argument '{name}' "
f"(expected '{hint.__name__}', got '{type(value).__name__}')"
f"(expected '{expected}', got '{type(value).__name__}')"
)
setattr(result, name, value)

View File

@@ -2,6 +2,7 @@
Declaration of `Server` class.
"""
import json
import shutil
import tempfile
from pathlib import Path
@@ -37,6 +38,9 @@ class Server:
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
_server_lambda: Union[JITLambda, LibraryLambda]
_mlir: Optional[str]
_configuration: Optional[Configuration]
def __init__(
self,
client_specs: ClientSpecs,
@@ -51,6 +55,7 @@ class Server:
self._support = support
self._compilation_result = compilation_result
self._server_lambda = server_lambda
self._mlir = None
assert_that(
support.load_client_parameters(compilation_result).serialize()
@@ -112,27 +117,61 @@ class Server:
client_parameters = support.load_client_parameters(compilation_result)
client_specs = ClientSpecs(input_signs, client_parameters, output_signs)
return Server(client_specs, output_dir, support, compilation_result, server_lambda)
def save(self, path: Union[str, Path]):
result = Server(client_specs, output_dir, support, compilation_result, server_lambda)
# pylint: disable=protected-access
result._mlir = mlir
result._configuration = configuration
# pylint: enable=protected-access
return result
def save(self, path: Union[str, Path], via_mlir: bool = False):
"""
Save the server into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the server
via_mlir (bool, default = False)
export using the MLIR code of the program,
this will make the export cross-platform
"""
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
if via_mlir:
if self._mlir is None or self._configuration is None:
raise RuntimeError("Loaded server objects cannot be saved again via MLIR")
with tempfile.TemporaryDirectory() as tmp:
with open(Path(tmp) / "circuit.mlir", "w", encoding="utf-8") as f:
f.write(self._mlir)
with open(Path(tmp) / "input_signs.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self.client_specs.input_signs))
with open(Path(tmp) / "output_signs.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self.client_specs.output_signs))
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._configuration.__dict__))
shutil.make_archive(path, "zip", tmp)
return
if self._output_dir is None:
raise RuntimeError("Just-in-Time compilation cannot be saved")
with open(Path(self._output_dir.name) / "client.specs.json", "w", encoding="utf-8") as f:
f.write(self.client_specs.serialize())
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
shutil.make_archive(path, "zip", self._output_dir.name)
@staticmethod
@@ -156,6 +195,25 @@ class Server:
shutil.unpack_archive(path, str(output_dir_path), "zip")
if (output_dir_path / "circuit.mlir").exists():
with open(output_dir_path / "circuit.mlir", "r", encoding="utf-8") as f:
mlir = f.read()
with open(output_dir_path / "input_signs.json", "r", encoding="utf-8") as f:
input_signs = json.load(f)
assert_that(isinstance(input_signs, list))
assert_that(all(isinstance(sign, bool) for sign in input_signs))
with open(output_dir_path / "output_signs.json", "r", encoding="utf-8") as f:
output_signs = json.load(f)
assert_that(isinstance(output_signs, list))
assert_that(all(isinstance(sign, bool) for sign in output_signs))
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**json.load(f))
return Server.create(mlir, input_signs, output_signs, configuration)
with open(output_dir_path / "client.specs.json", "r", encoding="utf-8") as f:
client_specs = ClientSpecs.unserialize(f.read())

View File

@@ -182,6 +182,66 @@ def test_client_server_api(helpers):
Client.load(client_path, configuration.insecure_key_cache_location),
]
for client in clients:
args = client.encrypt([3, 8, 1])
serialized_args = client.specs.serialize_public_args(args)
serialized_evaluation_keys = client.evaluation_keys.serialize()
unserialized_args = server.client_specs.unserialize_public_args(serialized_args)
unserialized_evaluation_keys = EvaluationKeys.unserialize(serialized_evaluation_keys)
result = server.run(unserialized_args, unserialized_evaluation_keys)
serialized_result = server.client_specs.serialize_public_result(result)
unserialized_result = client.specs.unserialize_public_result(serialized_result)
output = client.decrypt(unserialized_result)
assert np.array_equal(output, [45, 50, 43])
with pytest.raises(RuntimeError) as excinfo:
server.save("UNUSED", via_mlir=True)
assert str(excinfo.value) == "Loaded server objects cannot be saved again via MLIR"
server.cleanup()
def test_client_server_api_via_mlir(helpers):
"""
Test client/server API.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)]
circuit = function.compile(inputset, configuration.fork(jit=False))
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path, via_mlir=True)
client_path = tmp_dir_path / "client.zip"
circuit.client.save(client_path)
circuit.cleanup()
server = Server.load(server_path)
serialized_client_specs = server.client_specs.serialize()
client_specs = ClientSpecs.unserialize(serialized_client_specs)
clients = [
Client(client_specs, configuration.insecure_key_cache_location),
Client.load(client_path, configuration.insecure_key_cache_location),
]
for client in clients:
args = client.encrypt([3, 8, 1])

View File

@@ -68,6 +68,12 @@ def test_configuration_fork():
"Unexpected type for keyword argument 'dump_artifacts_on_unexpected_failures' "
"(expected 'bool', got 'str')",
),
pytest.param(
{"insecure_key_cache_location": 3},
TypeError,
"Unexpected type for keyword argument 'insecure_key_cache_location' "
"(expected 'typing.Union[str, NoneType]', got 'int')",
),
],
)
def test_configuration_bad_fork(kwargs, expected_error, expected_message):