refactor: use kwargs to store index during static indexing

This commit is contained in:
Umut
2022-07-25 15:20:46 +02:00
parent ecb70e2893
commit 46e275d233
4 changed files with 6 additions and 6 deletions

View File

@@ -725,7 +725,7 @@ class NodeConverter:
input_value = self.node.inputs[0]
input_shape = input_value.shape
index = list(self.node.properties["attributes"]["index"])
index = list(self.node.properties["kwargs"]["index"])
while len(index) < input_value.ndim:
index.append(slice(None, None, None))

View File

@@ -248,7 +248,7 @@ class Node:
name = self.properties["name"]
if name == "index.static":
index = self.properties["attributes"]["index"]
index = self.properties["kwargs"]["index"]
elements = [format_indexing_element(element) for element in index]
return f"{predecessors[0]}[{', '.join(elements)}]"
@@ -292,7 +292,7 @@ class Node:
assert_that(self.operation == Operation.Generic)
name = self.properties["name"]
return name if name != "index.static" else self.format(["index"])
return name if name != "index.static" else self.format([""])
@property
def converted_to_table_lookup(self) -> bool:

View File

@@ -666,8 +666,8 @@ class Tracer:
"index.static",
[self.output],
output_value,
lambda x: x[index],
attributes={"index": index},
lambda x, index: x[index],
kwargs={"index": index},
)
return Tracer(computation, [self])

View File

@@ -171,7 +171,7 @@ def test_node_bad_call(node, args, expected_error, expected_message):
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
operation=lambda x: x[slice(None, None, -1)],
attributes={"index": (slice(None, None, -1),)},
kwargs={"index": (slice(None, None, -1),)},
),
["%0"],
"%0[::-1]",