From 46e275d233f8695921a55c41706b06da868a8c2c Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 25 Jul 2022 15:20:46 +0200 Subject: [PATCH] refactor: use kwargs to store index during static indexing --- concrete/numpy/mlir/node_converter.py | 2 +- concrete/numpy/representation/node.py | 4 ++-- concrete/numpy/tracing/tracer.py | 4 ++-- tests/representation/test_node.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 6ce5a0b87..667f6184f 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -725,7 +725,7 @@ class NodeConverter: input_value = self.node.inputs[0] input_shape = input_value.shape - index = list(self.node.properties["attributes"]["index"]) + index = list(self.node.properties["kwargs"]["index"]) while len(index) < input_value.ndim: index.append(slice(None, None, None)) diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index 91066ae8e..5f2b6de00 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -248,7 +248,7 @@ class Node: name = self.properties["name"] if name == "index.static": - index = self.properties["attributes"]["index"] + index = self.properties["kwargs"]["index"] elements = [format_indexing_element(element) for element in index] return f"{predecessors[0]}[{', '.join(elements)}]" @@ -292,7 +292,7 @@ class Node: assert_that(self.operation == Operation.Generic) name = self.properties["name"] - return name if name != "index.static" else self.format(["index"]) + return name if name != "index.static" else self.format(["□"]) @property def converted_to_table_lookup(self) -> bool: diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index a5b996d08..606c4801e 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -666,8 +666,8 @@ class Tracer: "index.static", [self.output], output_value, - lambda x: x[index], - attributes={"index": index}, + lambda x, index: x[index], + kwargs={"index": index}, ) return Tracer(computation, [self]) diff --git a/tests/representation/test_node.py b/tests/representation/test_node.py index 301d478b2..1ff27f38e 100644 --- a/tests/representation/test_node.py +++ b/tests/representation/test_node.py @@ -171,7 +171,7 @@ def test_node_bad_call(node, args, expected_error, expected_message): inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))], output=EncryptedTensor(UnsignedInteger(3), shape=(3,)), operation=lambda x: x[slice(None, None, -1)], - attributes={"index": (slice(None, None, -1),)}, + kwargs={"index": (slice(None, None, -1),)}, ), ["%0"], "%0[::-1]",