mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: support library compilation and serialization
This commit is contained in:
@@ -2,6 +2,9 @@
|
||||
Declaration of `Circuit` class.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple, Union, cast
|
||||
|
||||
@@ -15,6 +18,9 @@ from concrete.compiler import (
|
||||
JITSupport,
|
||||
KeySet,
|
||||
KeySetCache,
|
||||
LibraryCompilationResult,
|
||||
LibraryLambda,
|
||||
LibrarySupport,
|
||||
PublicArguments,
|
||||
PublicResult,
|
||||
)
|
||||
@@ -31,55 +37,227 @@ class Circuit:
|
||||
Circuit class, to combine computation graph and compiler engine into a single object.
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
configuration: Configuration
|
||||
|
||||
graph: Graph
|
||||
mlir: str
|
||||
virtual: bool
|
||||
|
||||
_jit_support: JITSupport
|
||||
_compilation_result: JITCompilationResult
|
||||
_support: Union[JITSupport, LibrarySupport]
|
||||
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
|
||||
_server_lambda: Union[JITLambda, LibraryLambda]
|
||||
|
||||
_output_dir: Optional[tempfile.TemporaryDirectory]
|
||||
|
||||
_client_parameters: ClientParameters
|
||||
|
||||
_keyset_cache: KeySetCache
|
||||
_keyset: KeySet
|
||||
_keyset_cache: KeySetCache
|
||||
|
||||
_server_lambda: JITLambda
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
configuration: Configuration,
|
||||
graph: Graph,
|
||||
mlir: str,
|
||||
configuration: Optional[Configuration] = None,
|
||||
support: Optional[Union[JITSupport, LibrarySupport]] = None,
|
||||
compilation_result: Optional[Union[JITCompilationResult, LibraryCompilationResult]] = None,
|
||||
server_lambda: Optional[Union[JITLambda, LibraryLambda]] = None,
|
||||
output_dir: Optional[tempfile.TemporaryDirectory] = None,
|
||||
):
|
||||
configuration = configuration if configuration is not None else Configuration()
|
||||
self.configuration = configuration
|
||||
self._output_dir = output_dir
|
||||
|
||||
self.graph = graph
|
||||
self.mlir = mlir
|
||||
|
||||
self.virtual = configuration.virtual
|
||||
if self.virtual:
|
||||
if configuration.virtual:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
return
|
||||
|
||||
assert support is not None
|
||||
assert compilation_result is not None
|
||||
assert server_lambda is not None
|
||||
|
||||
assert_that(
|
||||
(
|
||||
isinstance(support, JITSupport)
|
||||
and isinstance(compilation_result, JITCompilationResult)
|
||||
and isinstance(server_lambda, JITLambda)
|
||||
)
|
||||
or (
|
||||
isinstance(support, LibrarySupport)
|
||||
and isinstance(compilation_result, LibraryCompilationResult)
|
||||
and isinstance(server_lambda, LibraryLambda)
|
||||
)
|
||||
)
|
||||
|
||||
self._support = support
|
||||
self._compilation_result = compilation_result
|
||||
self._server_lambda = server_lambda
|
||||
|
||||
self._output_dir = output_dir
|
||||
if isinstance(support, LibrarySupport):
|
||||
assert output_dir is not None
|
||||
assert_that(support.library_path == str(output_dir.name) + "/out")
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
keyset = None
|
||||
keyset_cache = None
|
||||
|
||||
if configuration.use_insecure_key_cache:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
location = Configuration.insecure_key_cache_location()
|
||||
if location is not None:
|
||||
keyset_cache = KeySetCache.new(str(location))
|
||||
|
||||
self._client_parameters = client_parameters
|
||||
self._keyset = keyset
|
||||
self._keyset_cache = keyset_cache
|
||||
|
||||
@staticmethod
|
||||
def create(graph: Graph, mlir: str, configuration: Optional[Configuration] = None) -> "Circuit":
|
||||
"""
|
||||
Create a circuit from a graph and its MLIR.
|
||||
|
||||
Args:
|
||||
graph (Graph):
|
||||
graph of the circuit
|
||||
|
||||
mlir (str):
|
||||
mlir of the circuit
|
||||
|
||||
configuration (Optional[Configuration], default = None):
|
||||
configuration to use
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
circuit of graph
|
||||
"""
|
||||
|
||||
configuration = configuration if configuration is not None else Configuration()
|
||||
if configuration.virtual:
|
||||
return Circuit(configuration, graph, mlir)
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
|
||||
options.set_auto_parallelize(configuration.auto_parallelize)
|
||||
|
||||
self._jit_support = JITSupport.new()
|
||||
self._compilation_result = self._jit_support.compile(mlir, options)
|
||||
if configuration.jit:
|
||||
|
||||
self._client_parameters = self._jit_support.load_client_parameters(self._compilation_result)
|
||||
output_dir = None
|
||||
|
||||
self._keyset_cache = None
|
||||
if configuration.use_insecure_key_cache:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
location = Configuration.insecure_key_cache_location()
|
||||
if location is not None:
|
||||
self._keyset_cache = KeySetCache.new(str(location))
|
||||
self._keyset = None
|
||||
support = JITSupport.new()
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
self._server_lambda = self._jit_support.load_server_lambda(self._compilation_result)
|
||||
else:
|
||||
|
||||
# pylint: disable=consider-using-with
|
||||
output_dir = tempfile.TemporaryDirectory()
|
||||
output_dir_path = Path(output_dir.name)
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
support = LibrarySupport.new(str(output_dir_path / "out"))
|
||||
compilation_result = support.compile(mlir, options)
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
return Circuit(
|
||||
configuration,
|
||||
graph,
|
||||
mlir,
|
||||
support,
|
||||
compilation_result,
|
||||
server_lambda,
|
||||
output_dir,
|
||||
)
|
||||
|
||||
def save(self, path: Union[str, Path]):
|
||||
"""
|
||||
Save the circuit into the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to save the circuit
|
||||
"""
|
||||
|
||||
if not self.configuration.virtual and self.configuration.jit:
|
||||
raise RuntimeError("JIT Circuits cannot be saved")
|
||||
|
||||
if self.configuration.virtual:
|
||||
# pylint: disable=consider-using-with
|
||||
self._output_dir = tempfile.TemporaryDirectory()
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
assert self._output_dir is not None
|
||||
output_dir_path = Path(self._output_dir.name)
|
||||
|
||||
with open(output_dir_path / "out.pickle", "wb") as f:
|
||||
attributes = {
|
||||
"configuration": self.configuration,
|
||||
"graph": self.graph,
|
||||
"mlir": self.mlir,
|
||||
}
|
||||
pickle.dump(attributes, f)
|
||||
|
||||
path = str(path)
|
||||
if path.endswith(".zip"):
|
||||
path = path[: len(path) - 4]
|
||||
|
||||
shutil.make_archive(path, "zip", str(output_dir_path))
|
||||
|
||||
if self.configuration.virtual:
|
||||
self.cleanup()
|
||||
self._output_dir = None
|
||||
|
||||
@staticmethod
|
||||
def load(path: Union[str, Path]) -> "Circuit":
|
||||
"""
|
||||
Load the circuit from the given path in zip format.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]):
|
||||
path to load the circuit from
|
||||
|
||||
Returns:
|
||||
Circuit:
|
||||
circuit loaded from the filesystem
|
||||
"""
|
||||
|
||||
# pylint: disable=consider-using-with
|
||||
output_dir = tempfile.TemporaryDirectory()
|
||||
output_dir_path = Path(output_dir.name)
|
||||
# pylint: enable=consider-using-with
|
||||
|
||||
shutil.unpack_archive(path, str(output_dir_path), "zip")
|
||||
|
||||
with open(output_dir_path / "out.pickle", "rb") as f:
|
||||
attributes = pickle.load(f)
|
||||
|
||||
configuration = attributes["configuration"]
|
||||
graph = attributes["graph"]
|
||||
mlir = attributes["mlir"]
|
||||
|
||||
if configuration.virtual:
|
||||
output_dir.cleanup()
|
||||
return Circuit(configuration, graph, mlir)
|
||||
|
||||
support = LibrarySupport.new(str(output_dir_path / "out"))
|
||||
compilation_result = support.reload("main")
|
||||
server_lambda = support.load_server_lambda(compilation_result)
|
||||
|
||||
return Circuit(
|
||||
configuration,
|
||||
graph,
|
||||
mlir,
|
||||
support,
|
||||
compilation_result,
|
||||
server_lambda,
|
||||
output_dir,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.graph.format()
|
||||
@@ -124,7 +302,7 @@ class Circuit:
|
||||
whether to generate new keys even if keys are already generated
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `keygen` method")
|
||||
|
||||
if self._keyset is None or force:
|
||||
@@ -143,7 +321,7 @@ class Circuit:
|
||||
encrypted and plain arguments as well as public keys
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `encrypt` method")
|
||||
|
||||
if len(args) != len(self.graph.input_nodes):
|
||||
@@ -205,10 +383,10 @@ class Circuit:
|
||||
encrypted result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `run` method")
|
||||
|
||||
return self._jit_support.server_call(self._server_lambda, args)
|
||||
return self._support.server_call(self._server_lambda, args)
|
||||
|
||||
def decrypt(
|
||||
self,
|
||||
@@ -226,7 +404,7 @@ class Circuit:
|
||||
clear result of homomorphic evaluaton
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
if self.configuration.virtual:
|
||||
raise RuntimeError("Virtual circuits cannot use `decrypt` method")
|
||||
|
||||
results = ClientSupport.decrypt_result(self._keyset, result)
|
||||
@@ -271,7 +449,15 @@ class Circuit:
|
||||
clear result of homomorphic evaluation
|
||||
"""
|
||||
|
||||
if self.virtual:
|
||||
if self.configuration.virtual:
|
||||
return self.graph(*args)
|
||||
|
||||
return self.decrypt(self.run(self.encrypt(*args)))
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Cleanup the temporary library output directory.
|
||||
"""
|
||||
|
||||
if self._output_dir is not None:
|
||||
self._output_dir.cleanup()
|
||||
|
||||
@@ -375,7 +375,7 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
return Circuit(self.graph, mlir, self.configuration)
|
||||
return Circuit.create(self.graph, mlir, self.configuration)
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ class Configuration:
|
||||
Configuration class, to allow the compilation process to be customized.
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
verbose: bool
|
||||
show_graph: bool
|
||||
show_mlir: bool
|
||||
@@ -23,6 +25,9 @@ class Configuration:
|
||||
loop_parallelize: bool
|
||||
dataflow_parallelize: bool
|
||||
auto_parallelize: bool
|
||||
jit: bool
|
||||
|
||||
# pylint: enable=too-many-instance-attributes
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
@@ -55,6 +60,7 @@ class Configuration:
|
||||
loop_parallelize: bool = True,
|
||||
dataflow_parallelize: bool = False,
|
||||
auto_parallelize: bool = False,
|
||||
jit: bool = False,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.show_graph = show_graph
|
||||
@@ -66,6 +72,7 @@ class Configuration:
|
||||
self.loop_parallelize = loop_parallelize
|
||||
self.dataflow_parallelize = dataflow_parallelize
|
||||
self.auto_parallelize = auto_parallelize
|
||||
self.jit = jit
|
||||
|
||||
self._validate()
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete.numpy import Circuit
|
||||
from concrete.numpy.compilation import compiler
|
||||
|
||||
|
||||
@@ -164,3 +165,86 @@ def test_circuit_virtual_explicit_api(helpers):
|
||||
circuit.decrypt(None)
|
||||
|
||||
assert str(excinfo.value) == "Virtual circuits cannot use `decrypt` method"
|
||||
|
||||
|
||||
def test_circuit_bad_save(helpers):
|
||||
"""
|
||||
Test `save` method of `Circuit` class with bad parameters.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
circuit.save("circuit.zip")
|
||||
|
||||
assert str(excinfo.value) == "JIT Circuits cannot be saved"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"virtual",
|
||||
[False, True],
|
||||
)
|
||||
def test_circuit_save_load(virtual, helpers):
|
||||
"""
|
||||
Test `save`, `load`, and `cleanup` methods of `Circuit` class.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration().fork(jit=False, virtual=virtual)
|
||||
|
||||
def save(base):
|
||||
@compiler({"x": "encrypted"})
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
inputset = range(10)
|
||||
circuit = function.compile(inputset, configuration)
|
||||
|
||||
circuit.save(base / "circuit.zip")
|
||||
circuit.cleanup()
|
||||
|
||||
def load(base):
|
||||
circuit = Circuit.load(base / "circuit.zip")
|
||||
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""",
|
||||
str(circuit),
|
||||
)
|
||||
if virtual:
|
||||
helpers.check_str("Virtual circuits doesn't have MLIR.", circuit.mlir)
|
||||
else:
|
||||
helpers.check_str(
|
||||
"""
|
||||
|
||||
module {
|
||||
func @main(%arg0: !FHE.eint<6>) -> !FHE.eint<6> {
|
||||
%c42_i7 = arith.constant 42 : i7
|
||||
%0 = "FHE.add_eint_int"(%arg0, %c42_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>
|
||||
return %0 : !FHE.eint<6>
|
||||
}
|
||||
}
|
||||
|
||||
""",
|
||||
circuit.mlir,
|
||||
)
|
||||
helpers.check_execution(circuit, lambda x: x + 42, 4)
|
||||
|
||||
circuit.cleanup()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp)
|
||||
save(path)
|
||||
load(path)
|
||||
|
||||
@@ -107,6 +107,7 @@ class Helpers:
|
||||
loop_parallelize=True,
|
||||
dataflow_parallelize=False,
|
||||
auto_parallelize=False,
|
||||
jit=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user