mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: make artifacts truly optional to improve performance for large graphs
This commit is contained in:
@@ -40,7 +40,7 @@ class Compiler:
|
||||
parameter_encryption_statuses: Dict[str, EncryptionStatus]
|
||||
|
||||
configuration: Configuration
|
||||
artifacts: DebugArtifacts
|
||||
artifacts: Optional[DebugArtifacts]
|
||||
|
||||
inputset: List[Any]
|
||||
graph: Optional[Graph]
|
||||
@@ -97,7 +97,7 @@ class Compiler:
|
||||
}
|
||||
|
||||
self.configuration = Configuration()
|
||||
self.artifacts = DebugArtifacts()
|
||||
self.artifacts = None
|
||||
|
||||
self.inputset = []
|
||||
self.graph = None
|
||||
@@ -136,9 +136,10 @@ class Compiler:
|
||||
sample to use for tracing
|
||||
"""
|
||||
|
||||
self.artifacts.add_source_code(self.function)
|
||||
for param, encryption_status in self.parameter_encryption_statuses.items():
|
||||
self.artifacts.add_parameter_encryption_status(param, encryption_status)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_source_code(self.function)
|
||||
for param, encryption_status in self.parameter_encryption_statuses.items():
|
||||
self.artifacts.add_parameter_encryption_status(param, encryption_status)
|
||||
|
||||
parameters = {
|
||||
param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
|
||||
@@ -149,7 +150,8 @@ class Compiler:
|
||||
}
|
||||
|
||||
self.graph = Tracer.trace(self.function, parameters)
|
||||
self.artifacts.add_graph("initial", self.graph)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("initial", self.graph)
|
||||
|
||||
fuse(self.graph, self.artifacts)
|
||||
|
||||
@@ -205,10 +207,12 @@ class Compiler:
|
||||
assert self.graph is not None
|
||||
|
||||
bounds = self.graph.measure_bounds(self.inputset)
|
||||
self.artifacts.add_final_graph_bounds(bounds)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_final_graph_bounds(bounds)
|
||||
|
||||
self.graph.update_with_bounds(bounds)
|
||||
self.artifacts.add_graph("final", self.graph)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_graph("final", self.graph)
|
||||
|
||||
def trace(
|
||||
self,
|
||||
@@ -243,12 +247,18 @@ class Compiler:
|
||||
|
||||
if configuration is not None:
|
||||
self.configuration = configuration
|
||||
if artifacts is not None:
|
||||
self.artifacts = artifacts
|
||||
|
||||
if len(kwargs) != 0:
|
||||
self.configuration = self.configuration.fork(**kwargs)
|
||||
|
||||
self.artifacts = (
|
||||
artifacts
|
||||
if artifacts is not None
|
||||
else DebugArtifacts()
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
self._evaluate("Tracing", inputset)
|
||||
@@ -292,6 +302,7 @@ class Compiler:
|
||||
# we need to export all the information we have about the compilation
|
||||
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures:
|
||||
assert self.artifacts is not None
|
||||
self.artifacts.export()
|
||||
|
||||
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
|
||||
@@ -340,19 +351,26 @@ class Compiler:
|
||||
|
||||
if configuration is not None:
|
||||
self.configuration = configuration
|
||||
if artifacts is not None:
|
||||
self.artifacts = artifacts
|
||||
|
||||
if len(kwargs) != 0:
|
||||
self.configuration = self.configuration.fork(**kwargs)
|
||||
|
||||
self.artifacts = (
|
||||
artifacts
|
||||
if artifacts is not None
|
||||
else DebugArtifacts()
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
self._evaluate("Compiling", inputset)
|
||||
assert self.graph is not None
|
||||
|
||||
mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual)
|
||||
self.artifacts.add_mlir_to_compile(mlir)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_mlir_to_compile(mlir)
|
||||
|
||||
if (
|
||||
self.configuration.verbose
|
||||
@@ -412,9 +430,10 @@ class Compiler:
|
||||
circuit = Circuit(self.graph, mlir, self.configuration)
|
||||
if not self.configuration.virtual:
|
||||
assert circuit.client.specs.client_parameters is not None
|
||||
self.artifacts.add_client_parameters(
|
||||
circuit.client.specs.client_parameters.serialize()
|
||||
)
|
||||
if self.artifacts is not None:
|
||||
self.artifacts.add_client_parameters(
|
||||
circuit.client.specs.client_parameters.serialize()
|
||||
)
|
||||
return circuit
|
||||
|
||||
except Exception: # pragma: no cover
|
||||
@@ -426,6 +445,7 @@ class Compiler:
|
||||
# we need to export all the information we have about the compilation
|
||||
|
||||
if self.configuration.dump_artifacts_on_unexpected_failures:
|
||||
assert self.artifacts is not None
|
||||
self.artifacts.export()
|
||||
|
||||
traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
|
||||
|
||||
@@ -286,11 +286,7 @@ class Node:
|
||||
assert_that(self.operation == Operation.Generic)
|
||||
|
||||
name = self.properties["name"]
|
||||
|
||||
if name == "index.static":
|
||||
name = self.format(["index"])
|
||||
|
||||
return name
|
||||
return name if name != "index.static" else self.format(["index"])
|
||||
|
||||
@property
|
||||
def converted_to_table_lookup(self) -> bool:
|
||||
|
||||
3
tests/extensions/__init__.py
Normal file
3
tests/extensions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Tests of extensions.
|
||||
"""
|
||||
28
tests/extensions/test_table.py
Normal file
28
tests/extensions/test_table.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Tests of LookupTable.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete.numpy import LookupTable
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"table, expected_result",
|
||||
[
|
||||
pytest.param(
|
||||
LookupTable([1, 2, 3]),
|
||||
"[1, 2, 3]",
|
||||
),
|
||||
pytest.param(
|
||||
LookupTable([LookupTable([1, 2, 3]), LookupTable([4, 5, 6])]),
|
||||
"[[1, 2, 3], [4, 5, 6]]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_lookup_table_repr(table, expected_result):
|
||||
"""
|
||||
Test `__repr__` method of `LookupTable` class.
|
||||
"""
|
||||
|
||||
assert repr(table) == expected_result
|
||||
@@ -165,6 +165,17 @@ def test_node_bad_call(node, args, expected_error, expected_message):
|
||||
["%0"],
|
||||
"tlu(%0, table=[4 1 3 2])",
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="index.static",
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
operation=lambda x: x[slice(None, None, -1)],
|
||||
attributes={"index": (slice(None, None, -1),)},
|
||||
),
|
||||
["%0"],
|
||||
"%0[::-1]",
|
||||
),
|
||||
pytest.param(
|
||||
Node.generic(
|
||||
name="concatenate",
|
||||
|
||||
Reference in New Issue
Block a user