diff --git a/concrete/numpy/compilation/compiler.py b/concrete/numpy/compilation/compiler.py index f48254b4f..333fd1200 100644 --- a/concrete/numpy/compilation/compiler.py +++ b/concrete/numpy/compilation/compiler.py @@ -40,7 +40,7 @@ class Compiler: parameter_encryption_statuses: Dict[str, EncryptionStatus] configuration: Configuration - artifacts: DebugArtifacts + artifacts: Optional[DebugArtifacts] inputset: List[Any] graph: Optional[Graph] @@ -97,7 +97,7 @@ class Compiler: } self.configuration = Configuration() - self.artifacts = DebugArtifacts() + self.artifacts = None self.inputset = [] self.graph = None @@ -136,9 +136,10 @@ class Compiler: sample to use for tracing """ - self.artifacts.add_source_code(self.function) - for param, encryption_status in self.parameter_encryption_statuses.items(): - self.artifacts.add_parameter_encryption_status(param, encryption_status) + if self.artifacts is not None: + self.artifacts.add_source_code(self.function) + for param, encryption_status in self.parameter_encryption_statuses.items(): + self.artifacts.add_parameter_encryption_status(param, encryption_status) parameters = { param: Value.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED)) @@ -149,7 +150,8 @@ class Compiler: } self.graph = Tracer.trace(self.function, parameters) - self.artifacts.add_graph("initial", self.graph) + if self.artifacts is not None: + self.artifacts.add_graph("initial", self.graph) fuse(self.graph, self.artifacts) @@ -205,10 +207,12 @@ class Compiler: assert self.graph is not None bounds = self.graph.measure_bounds(self.inputset) - self.artifacts.add_final_graph_bounds(bounds) + if self.artifacts is not None: + self.artifacts.add_final_graph_bounds(bounds) self.graph.update_with_bounds(bounds) - self.artifacts.add_graph("final", self.graph) + if self.artifacts is not None: + self.artifacts.add_graph("final", self.graph) def trace( self, @@ -243,12 +247,18 @@ class Compiler: if configuration is not None: self.configuration = configuration - if artifacts is not None: - self.artifacts = artifacts if len(kwargs) != 0: self.configuration = self.configuration.fork(**kwargs) + self.artifacts = ( + artifacts + if artifacts is not None + else DebugArtifacts() + if self.configuration.dump_artifacts_on_unexpected_failures + else None + ) + try: self._evaluate("Tracing", inputset) @@ -292,6 +302,7 @@ class Compiler: # we need to export all the information we have about the compilation if self.configuration.dump_artifacts_on_unexpected_failures: + assert self.artifacts is not None self.artifacts.export() traceback_path = self.artifacts.output_directory.joinpath("traceback.txt") @@ -340,19 +351,26 @@ class Compiler: if configuration is not None: self.configuration = configuration - if artifacts is not None: - self.artifacts = artifacts if len(kwargs) != 0: self.configuration = self.configuration.fork(**kwargs) + self.artifacts = ( + artifacts + if artifacts is not None + else DebugArtifacts() + if self.configuration.dump_artifacts_on_unexpected_failures + else None + ) + try: self._evaluate("Compiling", inputset) assert self.graph is not None mlir = GraphConverter.convert(self.graph, virtual=self.configuration.virtual) - self.artifacts.add_mlir_to_compile(mlir) + if self.artifacts is not None: + self.artifacts.add_mlir_to_compile(mlir) if ( self.configuration.verbose @@ -412,9 +430,10 @@ class Compiler: circuit = Circuit(self.graph, mlir, self.configuration) if not self.configuration.virtual: assert circuit.client.specs.client_parameters is not None - self.artifacts.add_client_parameters( - circuit.client.specs.client_parameters.serialize() - ) + if self.artifacts is not None: + self.artifacts.add_client_parameters( + circuit.client.specs.client_parameters.serialize() + ) return circuit except Exception: # pragma: no cover @@ -426,6 +445,7 @@ class Compiler: # we need to export all the information we have about the compilation if self.configuration.dump_artifacts_on_unexpected_failures: + assert self.artifacts is not None self.artifacts.export() traceback_path = self.artifacts.output_directory.joinpath("traceback.txt") diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index a8dcf5cb4..91a2d4b46 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -286,11 +286,7 @@ class Node: assert_that(self.operation == Operation.Generic) name = self.properties["name"] - - if name == "index.static": - name = self.format(["index"]) - - return name + return name if name != "index.static" else self.format(["index"]) @property def converted_to_table_lookup(self) -> bool: diff --git a/tests/extensions/__init__.py b/tests/extensions/__init__.py new file mode 100644 index 000000000..af9e23ae7 --- /dev/null +++ b/tests/extensions/__init__.py @@ -0,0 +1,3 @@ +""" +Tests of extensions. +""" diff --git a/tests/extensions/test_table.py b/tests/extensions/test_table.py new file mode 100644 index 000000000..188311994 --- /dev/null +++ b/tests/extensions/test_table.py @@ -0,0 +1,28 @@ +""" +Tests of LookupTable. +""" + +import pytest + +from concrete.numpy import LookupTable + + +@pytest.mark.parametrize( + "table, expected_result", + [ + pytest.param( + LookupTable([1, 2, 3]), + "[1, 2, 3]", + ), + pytest.param( + LookupTable([LookupTable([1, 2, 3]), LookupTable([4, 5, 6])]), + "[[1, 2, 3], [4, 5, 6]]", + ), + ], +) +def test_lookup_table_repr(table, expected_result): + """ + Test `__repr__` method of `LookupTable` class. + """ + + assert repr(table) == expected_result diff --git a/tests/representation/test_node.py b/tests/representation/test_node.py index 550236cf4..b854e7f9a 100644 --- a/tests/representation/test_node.py +++ b/tests/representation/test_node.py @@ -165,6 +165,17 @@ def test_node_bad_call(node, args, expected_error, expected_message): ["%0"], "tlu(%0, table=[4 1 3 2])", ), + pytest.param( + Node.generic( + name="index.static", + inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))], + output=EncryptedTensor(UnsignedInteger(3), shape=(3,)), + operation=lambda x: x[slice(None, None, -1)], + attributes={"index": (slice(None, None, -1),)}, + ), + ["%0"], + "%0[::-1]", + ), pytest.param( Node.generic( name="concatenate",