diff --git a/concrete/numpy/compilation/artifacts.py b/concrete/numpy/compilation/artifacts.py index 5141efe7a..b99fab395 100644 --- a/concrete/numpy/compilation/artifacts.py +++ b/concrete/numpy/compilation/artifacts.py @@ -34,6 +34,8 @@ class DebugArtifacts: mlir_to_compile: Optional[str] + client_parameters: Optional[bytes] + def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY): self.output_directory = Path(output_directory) @@ -48,6 +50,8 @@ class DebugArtifacts: self.mlir_to_compile = None + self.client_parameters = None + def add_source_code(self, function: Union[str, Callable]): """ Add source code of the function being compiled. @@ -127,6 +131,16 @@ class DebugArtifacts: self.mlir_to_compile = mlir + def add_client_parameters(self, client_parameters: bytes): + """ + Add client parameters used. + + Args: + client_parameters (bytes): client parameters + """ + + self.client_parameters = client_parameters + def export(self): """ Export the collected information to `self.output_directory`. @@ -205,3 +219,7 @@ class DebugArtifacts: assert self.final_graph is not None with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: f.write(f"{self.mlir_to_compile}\n") + + if self.client_parameters is not None: + with open(output_directory.joinpath("client_parameters.json"), "wb") as f: + f.write(self.client_parameters) diff --git a/concrete/numpy/compilation/circuit.py b/concrete/numpy/compilation/circuit.py index 4de1c2d8b..e7452043e 100644 --- a/concrete/numpy/compilation/circuit.py +++ b/concrete/numpy/compilation/circuit.py @@ -43,6 +43,7 @@ class Circuit: graph: Graph mlir: str + client_parameters: Optional[ClientParameters] _support: Union[JITSupport, LibrarySupport] _compilation_result: Union[JITCompilationResult, LibraryCompilationResult] @@ -50,7 +51,6 @@ class Circuit: _output_dir: Optional[tempfile.TemporaryDirectory] - _client_parameters: ClientParameters _keyset: KeySet _keyset_cache: KeySetCache @@ -71,6 +71,7 @@ class Circuit: self.graph = graph self.mlir = mlir + self.client_parameters = None if configuration.virtual: assert_that(configuration.enable_unsafe_features) @@ -102,7 +103,7 @@ class Circuit: assert output_dir is not None assert_that(support.library_path == str(output_dir.name) + "/out") - client_parameters = support.load_client_parameters(compilation_result) + self.client_parameters = support.load_client_parameters(compilation_result) keyset = None keyset_cache = None @@ -112,7 +113,6 @@ class Circuit: if location is not None: keyset_cache = KeySetCache.new(str(location)) - self._client_parameters = client_parameters self._keyset = keyset self._keyset_cache = keyset_cache @@ -306,7 +306,7 @@ class Circuit: raise RuntimeError("Virtual circuits cannot use `keygen` method") if self._keyset is None or force: - self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache) + self._keyset = ClientSupport.key_set(self.client_parameters, self._keyset_cache) def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments: """ @@ -365,7 +365,7 @@ class Circuit: self.keygen(force=False) return ClientSupport.encrypt_arguments( - self._client_parameters, + self.client_parameters, self._keyset, [sanitized_args[i] for i in range(len(sanitized_args))], ) @@ -420,14 +420,14 @@ class Circuit: expected_dtype = cast(Integer, expected_value.dtype) n = expected_dtype.bit_width - result = results[index] % (2 ** n) + result = results[index] % (2**n) if expected_dtype.is_signed: if isinstance(result, int): - sanititzed_result = result if result < (2 ** (n - 1)) else result - (2 ** n) + sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n) sanitized_results.append(sanititzed_result) else: result = result.astype(np.longlong) # to prevent overflows in numpy - sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2 ** n)) + sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n)) sanitized_results.append(sanititzed_result.astype(np.int8)) else: sanitized_results.append( diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index 6bac7d54e..7978686ee 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -375,7 +375,11 @@ class Compiler: print() - return Circuit.create(self.graph, mlir, self.configuration) + circuit = Circuit.create(self.graph, mlir, self.configuration) + if not self.configuration.virtual: + assert circuit.client_parameters is not None + self.artifacts.add_client_parameters(circuit.client_parameters.serialize()) + return circuit except Exception: # pragma: no cover diff --git a/tests/compilation/test_artifacts.py b/tests/compilation/test_artifacts.py index 35f6d035a..710749285 100644 --- a/tests/compilation/test_artifacts.py +++ b/tests/compilation/test_artifacts.py @@ -42,6 +42,7 @@ def test_artifacts_export(helpers): assert (tmpdir / "bounds.txt").exists() assert (tmpdir / "mlir.txt").exists() + assert (tmpdir / "client_parameters.json").exists() artifacts.export() @@ -59,3 +60,4 @@ def test_artifacts_export(helpers): assert (tmpdir / "bounds.txt").exists() assert (tmpdir / "mlir.txt").exists() + assert (tmpdir / "client_parameters.json").exists()