fix: make artifacts truly optional to improve performance for large graphs

This commit is contained in:
Umut
2022-06-22 15:14:15 +02:00
parent 1e86c3b1e4
commit 1cc5b576eb
5 changed files with 79 additions and 21 deletions

View File

@@ -40,7 +40,7 @@ class Compiler:
parameter_encryption_statuses: Dict[str, EncryptionStatus]
configuration: Configuration
artifacts: DebugArtifacts
artifacts: Optional[DebugArtifacts]
inputset: List[Any]
graph: Optional[Graph]
@@ -97,7 +97,7 @@ class Compiler:
}
self.configuration = Configuration()
self.artifacts = DebugArtifacts()
self.artifacts = None
self.inputset = []
self.graph = None
@@ -136,9 +136,10 @@ class Compiler:
sample to use for tracing
"""
self.artifacts.add_source_code(self.function)
for param, encryption_status in self.parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
if self.artifacts is not None:
self.artifacts.add_source_code(self.function)
for param, encryption_status in self.parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
parameters = {
param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
@@ -149,7 +150,8 @@ class Compiler:
}
self.graph = Tracer.trace(self.function, parameters)
self.artifacts.add_graph("initial", self.graph)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph)
fuse(self.graph, self.artifacts)
@@ -205,10 +207,12 @@ class Compiler:
assert self.graph is not None
bounds = self.graph.measure_bounds(self.inputset)
self.artifacts.add_final_graph_bounds(bounds)
if self.artifacts is not None:
self.artifacts.add_final_graph_bounds(bounds)
self.graph.update_with_bounds(bounds)
self.artifacts.add_graph("final", self.graph)
if self.artifacts is not None:
self.artifacts.add_graph("final", self.graph)
def trace(
self,
@@ -243,12 +247,18 @@ class Compiler:
if configuration is not None:
self.configuration = configuration
if artifacts is not None:
self.artifacts = artifacts
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
self.artifacts = (
artifacts
if artifacts is not None
else DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
try:
self._evaluate("Tracing", inputset)
@@ -292,6 +302,7 @@ class Compiler:
# we need to export all the information we have about the compilation
if self.configuration.dump_artifacts_on_unexpected_failures:
assert self.artifacts is not None
self.artifacts.export()
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
@@ -340,19 +351,26 @@ class Compiler:
if configuration is not None:
self.configuration = configuration
if artifacts is not None:
self.artifacts = artifacts
if len(kwargs) != 0:
self.configuration = self.configuration.fork(**kwargs)
self.artifacts = (
artifacts
if artifacts is not None
else DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
try:
self._evaluate("Compiling", inputset)
assert self.graph is not None
mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual)
self.artifacts.add_mlir_to_compile(mlir)
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)
if (
self.configuration.verbose
@@ -412,9 +430,10 @@ class Compiler:
circuit = Circuit(self.graph, mlir, self.configuration)
if not self.configuration.virtual:
assert circuit.client.specs.client_parameters is not None
self.artifacts.add_client_parameters(
circuit.client.specs.client_parameters.serialize()
)
if self.artifacts is not None:
self.artifacts.add_client_parameters(
circuit.client.specs.client_parameters.serialize()
)
return circuit
except Exception: # pragma: no cover
@@ -426,6 +445,7 @@ class Compiler:
# we need to export all the information we have about the compilation
if self.configuration.dump_artifacts_on_unexpected_failures:
assert self.artifacts is not None
self.artifacts.export()
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")

View File

@@ -286,11 +286,7 @@ class Node:
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
if name == "index.static":
name = self.format(["index"])
return name
return name if name != "index.static" else self.format(["index"])
@property
def converted_to_table_lookup(self) -> bool:

View File

@@ -0,0 +1,3 @@
"""
Tests of extensions.
"""

View File

@@ -0,0 +1,28 @@
"""
Tests of LookupTable.
"""
import pytest
from concrete.numpy import LookupTable
@pytest.mark.parametrize(
"table, expected_result",
[
pytest.param(
LookupTable([1, 2, 3]),
"[1, 2, 3]",
),
pytest.param(
LookupTable([LookupTable([1, 2, 3]), LookupTable([4, 5, 6])]),
"[[1, 2, 3], [4, 5, 6]]",
),
],
)
def test_lookup_table_repr(table, expected_result):
"""
Test `__repr__` method of `LookupTable` class.
"""
assert repr(table) == expected_result

View File

@@ -165,6 +165,17 @@ def test_node_bad_call(node, args, expected_error, expected_message):
["%0"],
"tlu(%0, table=[4 1 3 2])",
),
pytest.param(
Node.generic(
name="index.static",
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
operation=lambda x: x[slice(None, None, -1)],
attributes={"index": (slice(None, None, -1),)},
),
["%0"],
"%0[::-1]",
),
pytest.param(
Node.generic(
name="concatenate",