feat: support library compilation and serialization

This commit is contained in:
Umut
2022-05-05 16:45:11 +02:00
parent 6739e2e8ab
commit 6662b71dfe
5 changed files with 306 additions and 28 deletions

View File

@@ -2,6 +2,9 @@
Declaration of `Circuit` class.
"""
import pickle
import shutil
import tempfile
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union, cast
@@ -15,6 +18,9 @@ from concrete.compiler import (
JITSupport,
KeySet,
KeySetCache,
LibraryCompilationResult,
LibraryLambda,
LibrarySupport,
PublicArguments,
PublicResult,
)
@@ -31,55 +37,227 @@ class Circuit:
Circuit class, to combine computation graph and compiler engine into a single object.
"""
# pylint: disable=too-many-instance-attributes
configuration: Configuration
graph: Graph
mlir: str
virtual: bool
_jit_support: JITSupport
_compilation_result: JITCompilationResult
_support: Union[JITSupport, LibrarySupport]
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
_server_lambda: Union[JITLambda, LibraryLambda]
_output_dir: Optional[tempfile.TemporaryDirectory]
_client_parameters: ClientParameters
_keyset_cache: KeySetCache
_keyset: KeySet
_keyset_cache: KeySetCache
_server_lambda: JITLambda
# pylint: enable=too-many-instance-attributes
def __init__(
self,
configuration: Configuration,
graph: Graph,
mlir: str,
configuration: Optional[Configuration] = None,
support: Optional[Union[JITSupport, LibrarySupport]] = None,
compilation_result: Optional[Union[JITCompilationResult, LibraryCompilationResult]] = None,
server_lambda: Optional[Union[JITLambda, LibraryLambda]] = None,
output_dir: Optional[tempfile.TemporaryDirectory] = None,
):
configuration = configuration if configuration is not None else Configuration()
self.configuration = configuration
self._output_dir = output_dir
self.graph = graph
self.mlir = mlir
self.virtual = configuration.virtual
if self.virtual:
if configuration.virtual:
assert_that(configuration.enable_unsafe_features)
return
assert support is not None
assert compilation_result is not None
assert server_lambda is not None
assert_that(
(
isinstance(support, JITSupport)
and isinstance(compilation_result, JITCompilationResult)
and isinstance(server_lambda, JITLambda)
)
or (
isinstance(support, LibrarySupport)
and isinstance(compilation_result, LibraryCompilationResult)
and isinstance(server_lambda, LibraryLambda)
)
)
self._support = support
self._compilation_result = compilation_result
self._server_lambda = server_lambda
self._output_dir = output_dir
if isinstance(support, LibrarySupport):
assert output_dir is not None
assert_that(support.library_path == str(output_dir.name) + "/out")
client_parameters = support.load_client_parameters(compilation_result)
keyset = None
keyset_cache = None
if configuration.use_insecure_key_cache:
assert_that(configuration.enable_unsafe_features)
location = Configuration.insecure_key_cache_location()
if location is not None:
keyset_cache = KeySetCache.new(str(location))
self._client_parameters = client_parameters
self._keyset = keyset
self._keyset_cache = keyset_cache
@staticmethod
def create(graph: Graph, mlir: str, configuration: Optional[Configuration] = None) -> "Circuit":
"""
Create a circuit from a graph and its MLIR.
Args:
graph (Graph):
graph of the circuit
mlir (str):
mlir of the circuit
configuration (Optional[Configuration], default = None):
configuration to use
Returns:
Circuit:
circuit of graph
"""
configuration = configuration if configuration is not None else Configuration()
if configuration.virtual:
return Circuit(configuration, graph, mlir)
options = CompilationOptions.new("main")
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
options.set_auto_parallelize(configuration.auto_parallelize)
self._jit_support = JITSupport.new()
self._compilation_result = self._jit_support.compile(mlir, options)
if configuration.jit:
self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result)
output_dir = None
self._keyset_cache = None
if configuration.use_insecure_key_cache:
assert_that(configuration.enable_unsafe_features)
location = Configuration.insecure_key_cache_location()
if location is not None:
self._keyset_cache = KeySetCache.new(str(location))
self._keyset = None
support = JITSupport.new()
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result)
else:
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
support = LibrarySupport.new(str(output_dir_path / "out"))
compilation_result = support.compile(mlir, options)
server_lambda = support.load_server_lambda(compilation_result)
return Circuit(
configuration,
graph,
mlir,
support,
compilation_result,
server_lambda,
output_dir,
)
def save(self, path: Union[str, Path]):
"""
Save the circuit into the given path in zip format.
Args:
path (Union[str, Path]):
path to save the circuit
"""
if not self.configuration.virtual and self.configuration.jit:
raise RuntimeError("JIT Circuits cannot be saved")
if self.configuration.virtual:
# pylint: disable=consider-using-with
self._output_dir = tempfile.TemporaryDirectory()
# pylint: enable=consider-using-with
assert self._output_dir is not None
output_dir_path = Path(self._output_dir.name)
with open(output_dir_path / "out.pickle", "wb") as f:
attributes = {
"configuration": self.configuration,
"graph": self.graph,
"mlir": self.mlir,
}
pickle.dump(attributes, f)
path = str(path)
if path.endswith(".zip"):
path = path[: len(path) - 4]
shutil.make_archive(path, "zip", str(output_dir_path))
if self.configuration.virtual:
self.cleanup()
self._output_dir = None
@staticmethod
def load(path: Union[str, Path]) -> "Circuit":
"""
Load the circuit from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the circuit from
Returns:
Circuit:
circuit loaded from the filesystem
"""
# pylint: disable=consider-using-with
output_dir = tempfile.TemporaryDirectory()
output_dir_path = Path(output_dir.name)
# pylint: enable=consider-using-with
shutil.unpack_archive(path, str(output_dir_path), "zip")
with open(output_dir_path / "out.pickle", "rb") as f:
attributes = pickle.load(f)
configuration = attributes["configuration"]
graph = attributes["graph"]
mlir = attributes["mlir"]
if configuration.virtual:
output_dir.cleanup()
return Circuit(configuration, graph, mlir)
support = LibrarySupport.new(str(output_dir_path / "out"))
compilation_result = support.reload("main")
server_lambda = support.load_server_lambda(compilation_result)
return Circuit(
configuration,
graph,
mlir,
support,
compilation_result,
server_lambda,
output_dir,
)
def __str__(self):
return self.graph.format()
@@ -124,7 +302,7 @@ class Circuit:
whether to generate new keys even if keys are already generated
"""
if self.virtual:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `keygen` method")
if self._keyset is None or force:
@@ -143,7 +321,7 @@ class Circuit:
encrypted and plain arguments as well as public keys
"""
if self.virtual:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
if len(args) != len(self.graph.input_nodes):
@@ -205,10 +383,10 @@ class Circuit:
encrypted result of homomorphic evaluaton
"""
if self.virtual:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `run` method")
return self._jit_support.server_call(self._server_lambda, args)
return self._support.server_call(self._server_lambda, args)
def decrypt(
self,
@@ -226,7 +404,7 @@ class Circuit:
clear result of homomorphic evaluaton
"""
if self.virtual:
if self.configuration.virtual:
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
results = ClientSupport.decrypt_result(self._keyset, result)
@@ -271,7 +449,15 @@ class Circuit:
clear result of homomorphic evaluation
"""
if self.virtual:
if self.configuration.virtual:
return self.graph(*args)
return self.decrypt(self.run(self.encrypt(*args)))
def cleanup(self):
"""
Cleanup the temporary library output directory.
"""
if self._output_dir is not None:
self._output_dir.cleanup()

View File

@@ -375,7 +375,7 @@ class Compiler:
print()
return Circuit(self.graph, mlir, self.configuration)
return Circuit.create(self.graph, mlir, self.configuration)
except Exception: # pragma: no cover

View File

@@ -13,6 +13,8 @@ class Configuration:
Configuration class, to allow the compilation process to be customized.
"""
# pylint: disable=too-many-instance-attributes
verbose: bool
show_graph: bool
show_mlir: bool
@@ -23,6 +25,9 @@ class Configuration:
loop_parallelize: bool
dataflow_parallelize: bool
auto_parallelize: bool
jit: bool
# pylint: enable=too-many-instance-attributes
def _validate(self):
"""
@@ -55,6 +60,7 @@ class Configuration:
loop_parallelize: bool = True,
dataflow_parallelize: bool = False,
auto_parallelize: bool = False,
jit: bool = False,
):
self.verbose = verbose
self.show_graph = show_graph
@@ -66,6 +72,7 @@ class Configuration:
self.loop_parallelize = loop_parallelize
self.dataflow_parallelize = dataflow_parallelize
self.auto_parallelize = auto_parallelize
self.jit = jit
self._validate()

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import numpy as np
import pytest
from concrete.numpy import Circuit
from concrete.numpy.compilation import compiler
@@ -164,3 +165,86 @@ def test_circuit_virtual_explicit_api(helpers):
circuit.decrypt(None)
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"
def test_circuit_bad_save(helpers):
"""
Test `save` method of `Circuit` class with bad parameters.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
circuit = function.compile(inputset, configuration)
with pytest.raises(RuntimeError) as excinfo:
circuit.save("circuit.zip")
assert str(excinfo.value) == "JIT Circuits cannot be saved"
@pytest.mark.parametrize(
"virtual",
[False, True],
)
def test_circuit_save_load(virtual, helpers):
"""
Test `save`, `load`, and `cleanup` methods of `Circuit` class.
"""
configuration = helpers.configuration().fork(jit=False, virtual=virtual)
def save(base):
@compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = range(10)
circuit = function.compile(inputset, configuration)
circuit.save(base / "circuit.zip")
circuit.cleanup()
def load(base):
circuit = Circuit.load(base / "circuit.zip")
helpers.check_str(
"""
%0 = x # EncryptedScalar<uint4>
%1 = 42 # ClearScalar<uint6>
%2 = add(%0, %1) # EncryptedScalar<uint6>
return %2
""",
str(circuit),
)
if virtual:
helpers.check_str("Virtual circuits doesn't have MLIR.", circuit.mlir)
else:
helpers.check_str(
"""
module {
func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
%c42_i7 = arith.constant 42 : i7
%0 = "FHE.add_eint_int"(%arg0, %c42_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
return %0 : !FHE.eint<6>
}
}
""",
circuit.mlir,
)
helpers.check_execution(circuit, lambda x: x + 42, 4)
circuit.cleanup()
with tempfile.TemporaryDirectory() as tmp:
path = Path(tmp)
save(path)
load(path)

View File

@@ -107,6 +107,7 @@ class Helpers:
loop_parallelize=True,
dataflow_parallelize=False,
auto_parallelize=False,
jit=True,
)
@staticmethod