mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add client parameters to debug artifacts
This commit is contained in:
@@ -34,6 +34,8 @@ class DebugArtifacts:
|
||||
|
||||
mlir_to_compile: Optional[str]
|
||||
|
||||
client_parameters: Optional[bytes]
|
||||
|
||||
def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
|
||||
self.output_directory = Path(output_directory)
|
||||
|
||||
@@ -48,6 +50,8 @@ class DebugArtifacts:
|
||||
|
||||
self.mlir_to_compile = None
|
||||
|
||||
self.client_parameters = None
|
||||
|
||||
def add_source_code(self, function: Union[str, Callable]):
|
||||
"""
|
||||
Add source code of the function being compiled.
|
||||
@@ -127,6 +131,16 @@ class DebugArtifacts:
|
||||
|
||||
self.mlir_to_compile = mlir
|
||||
|
||||
def add_client_parameters(self, client_parameters: bytes):
|
||||
"""
|
||||
Add client parameters used.
|
||||
|
||||
Args:
|
||||
client_parameters (bytes): client parameters
|
||||
"""
|
||||
|
||||
self.client_parameters = client_parameters
|
||||
|
||||
def export(self):
|
||||
"""
|
||||
Export the collected information to `self.output_directory`.
|
||||
@@ -205,3 +219,7 @@ class DebugArtifacts:
|
||||
assert self.final_graph is not None
|
||||
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{self.mlir_to_compile}\n")
|
||||
|
||||
if self.client_parameters is not None:
|
||||
with open(output_directory.joinpath("client_parameters.json"), "wb") as f:
|
||||
f.write(self.client_parameters)
|
||||
|
||||
@@ -43,6 +43,7 @@ class Circuit:
|
||||
|
||||
graph: Graph
|
||||
mlir: str
|
||||
client_parameters: Optional[ClientParameters]
|
||||
|
||||
_support: Union[JITSupport, LibrarySupport]
|
||||
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
|
||||
@@ -50,7 +51,6 @@ class Circuit:
|
||||
|
||||
_output_dir: Optional[tempfile.TemporaryDirectory]
|
||||
|
||||
_client_parameters: ClientParameters
|
||||
_keyset: KeySet
|
||||
_keyset_cache: KeySetCache
|
||||
|
||||
@@ -71,6 +71,7 @@ class Circuit:
|
||||
|
||||
self.graph = graph
|
||||
self.mlir = mlir
|
||||
self.client_parameters = None
|
||||
|
||||
if configuration.virtual:
|
||||
assert_that(configuration.enable_unsafe_features)
|
||||
@@ -102,7 +103,7 @@ class Circuit:
|
||||
assert output_dir is not None
|
||||
assert_that(support.library_path == str(output_dir.name) + "/out")
|
||||
|
||||
client_parameters = support.load_client_parameters(compilation_result)
|
||||
self.client_parameters = support.load_client_parameters(compilation_result)
|
||||
keyset = None
|
||||
keyset_cache = None
|
||||
|
||||
@@ -112,7 +113,6 @@ class Circuit:
|
||||
if location is not None:
|
||||
keyset_cache = KeySetCache.new(str(location))
|
||||
|
||||
self._client_parameters = client_parameters
|
||||
self._keyset = keyset
|
||||
self._keyset_cache = keyset_cache
|
||||
|
||||
@@ -306,7 +306,7 @@ class Circuit:
|
||||
raise RuntimeError("Virtual circuits cannot use `keygen` method")
|
||||
|
||||
if self._keyset is None or force:
|
||||
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
|
||||
self._keyset = ClientSupport.key_set(self.client_parameters, self._keyset_cache)
|
||||
|
||||
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
|
||||
"""
|
||||
@@ -365,7 +365,7 @@ class Circuit:
|
||||
|
||||
self.keygen(force=False)
|
||||
return ClientSupport.encrypt_arguments(
|
||||
self._client_parameters,
|
||||
self.client_parameters,
|
||||
self._keyset,
|
||||
[sanitized_args[i] for i in range(len(sanitized_args))],
|
||||
)
|
||||
@@ -420,14 +420,14 @@ class Circuit:
|
||||
expected_dtype = cast(Integer, expected_value.dtype)
|
||||
n = expected_dtype.bit_width
|
||||
|
||||
result = results[index] % (2 ** n)
|
||||
result = results[index] % (2**n)
|
||||
if expected_dtype.is_signed:
|
||||
if isinstance(result, int):
|
||||
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2 ** n)
|
||||
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
|
||||
sanitized_results.append(sanititzed_result)
|
||||
else:
|
||||
result = result.astype(np.longlong) # to prevent overflows in numpy
|
||||
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2 ** n))
|
||||
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
|
||||
sanitized_results.append(sanititzed_result.astype(np.int8))
|
||||
else:
|
||||
sanitized_results.append(
|
||||
|
||||
@@ -375,7 +375,11 @@ class Compiler:
|
||||
|
||||
print()
|
||||
|
||||
return Circuit.create(self.graph, mlir, self.configuration)
|
||||
circuit = Circuit.create(self.graph, mlir, self.configuration)
|
||||
if not self.configuration.virtual:
|
||||
assert circuit.client_parameters is not None
|
||||
self.artifacts.add_client_parameters(circuit.client_parameters.serialize())
|
||||
return circuit
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ def test_artifacts_export(helpers):
|
||||
|
||||
assert (tmpdir / "bounds.txt").exists()
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
assert (tmpdir / "client_parameters.json").exists()
|
||||
|
||||
artifacts.export()
|
||||
|
||||
@@ -59,3 +60,4 @@ def test_artifacts_export(helpers):
|
||||
|
||||
assert (tmpdir / "bounds.txt").exists()
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
assert (tmpdir / "client_parameters.json").exists()
|
||||
|
||||
Reference in New Issue
Block a user