feat: add client parameters to debug artifacts

This commit is contained in:
youben11
2022-05-09 16:50:48 +01:00
committed by Ayoub Benaissa
parent e90a9f1a55
commit 9bd587695f
4 changed files with 33 additions and 9 deletions

View File

@@ -34,6 +34,8 @@ class DebugArtifacts:
mlir_to_compile: Optional[str]
client_parameters: Optional[bytes]
def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
self.output_directory = Path(output_directory)
@@ -48,6 +50,8 @@ class DebugArtifacts:
self.mlir_to_compile = None
self.client_parameters = None
def add_source_code(self, function: Union[str, Callable]):
"""
Add source code of the function being compiled.
@@ -127,6 +131,16 @@ class DebugArtifacts:
self.mlir_to_compile = mlir
def add_client_parameters(self, client_parameters: bytes):
"""
Add client parameters used.
Args:
client_parameters (bytes): client parameters
"""
self.client_parameters = client_parameters
def export(self):
"""
Export the collected information to `self.output_directory`.
@@ -205,3 +219,7 @@ class DebugArtifacts:
assert self.final_graph is not None
with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f:
f.write(f"{self.mlir_to_compile}\n")
if self.client_parameters is not None:
with open(output_directory.joinpath("client_parameters.json"), "wb") as f:
f.write(self.client_parameters)

View File

@@ -43,6 +43,7 @@ class Circuit:
graph: Graph
mlir: str
client_parameters: Optional[ClientParameters]
_support: Union[JITSupport, LibrarySupport]
_compilation_result: Union[JITCompilationResult, LibraryCompilationResult]
@@ -50,7 +51,6 @@ class Circuit:
_output_dir: Optional[tempfile.TemporaryDirectory]
_client_parameters: ClientParameters
_keyset: KeySet
_keyset_cache: KeySetCache
@@ -71,6 +71,7 @@ class Circuit:
self.graph = graph
self.mlir = mlir
self.client_parameters = None
if configuration.virtual:
assert_that(configuration.enable_unsafe_features)
@@ -102,7 +103,7 @@ class Circuit:
assert output_dir is not None
assert_that(support.library_path == str(output_dir.name) + "/out")
client_parameters = support.load_client_parameters(compilation_result)
self.client_parameters = support.load_client_parameters(compilation_result)
keyset = None
keyset_cache = None
@@ -112,7 +113,6 @@ class Circuit:
if location is not None:
keyset_cache = KeySetCache.new(str(location))
self._client_parameters = client_parameters
self._keyset = keyset
self._keyset_cache = keyset_cache
@@ -306,7 +306,7 @@ class Circuit:
raise RuntimeError("Virtual circuits cannot use `keygen` method")
if self._keyset is None or force:
self._keyset = ClientSupport.key_set(self._client_parameters, self._keyset_cache)
self._keyset = ClientSupport.key_set(self.client_parameters, self._keyset_cache)
def encrypt(self, *args: Union[int, np.ndarray]) -> PublicArguments:
"""
@@ -365,7 +365,7 @@ class Circuit:
self.keygen(force=False)
return ClientSupport.encrypt_arguments(
self._client_parameters,
self.client_parameters,
self._keyset,
[sanitized_args[i] for i in range(len(sanitized_args))],
)
@@ -420,14 +420,14 @@ class Circuit:
expected_dtype = cast(Integer, expected_value.dtype)
n = expected_dtype.bit_width
result = results[index] % (2 ** n)
result = results[index] % (2**n)
if expected_dtype.is_signed:
if isinstance(result, int):
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2 ** n)
sanititzed_result = result if result < (2 ** (n - 1)) else result - (2**n)
sanitized_results.append(sanititzed_result)
else:
result = result.astype(np.longlong) # to prevent overflows in numpy
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2 ** n))
sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n))
sanitized_results.append(sanititzed_result.astype(np.int8))
else:
sanitized_results.append(

View File

@@ -375,7 +375,11 @@ class Compiler:
print()
return Circuit.create(self.graph, mlir, self.configuration)
circuit = Circuit.create(self.graph, mlir, self.configuration)
if not self.configuration.virtual:
assert circuit.client_parameters is not None
self.artifacts.add_client_parameters(circuit.client_parameters.serialize())
return circuit
except Exception: # pragma: no cover

View File

@@ -42,6 +42,7 @@ def test_artifacts_export(helpers):
assert (tmpdir / "bounds.txt").exists()
assert (tmpdir / "mlir.txt").exists()
assert (tmpdir / "client_parameters.json").exists()
artifacts.export()
@@ -59,3 +60,4 @@ def test_artifacts_export(helpers):
assert (tmpdir / "bounds.txt").exists()
assert (tmpdir / "mlir.txt").exists()
assert (tmpdir / "client_parameters.json").exists()