mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
refactor: use kwargs to store index during static indexing
This commit is contained in:
@@ -725,7 +725,7 @@ class NodeConverter:
|
||||
input_value = self.node.inputs[0]
|
||||
input_shape = input_value.shape
|
||||
|
||||
index = list(self.node.properties["attributes"]["index"])
|
||||
index = list(self.node.properties["kwargs"]["index"])
|
||||
|
||||
while len(index) < input_value.ndim:
|
||||
index.append(slice(None, None, None))
|
||||
|
||||
@@ -248,7 +248,7 @@ class Node:
|
||||
name = self.properties["name"]
|
||||
|
||||
if name == "index.static":
|
||||
index = self.properties["attributes"]["index"]
|
||||
index = self.properties["kwargs"]["index"]
|
||||
elements = [format_indexing_element(element) for element in index]
|
||||
return f"{predecessors[0]}[{', '.join(elements)}]"
|
||||
|
||||
@@ -292,7 +292,7 @@ class Node:
|
||||
assert_that(self.operation == Operation.Generic)
|
||||
|
||||
name = self.properties["name"]
|
||||
return name if name != "index.static" else self.format(["index"])
|
||||
return name if name != "index.static" else self.format(["□"])
|
||||
|
||||
@property
|
||||
def converted_to_table_lookup(self) -> bool:
|
||||
|
||||
@@ -666,8 +666,8 @@ class Tracer:
|
||||
"index.static",
|
||||
[self.output],
|
||||
output_value,
|
||||
lambda x: x[index],
|
||||
attributes={"index": index},
|
||||
lambda x, index: x[index],
|
||||
kwargs={"index": index},
|
||||
)
|
||||
return Tracer(computation, [self])
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ def test_node_bad_call(node, args, expected_error, expected_message):
|
||||
inputs=[EncryptedTensor(UnsignedInteger(3), shape=(3,))],
|
||||
output=EncryptedTensor(UnsignedInteger(3), shape=(3,)),
|
||||
operation=lambda x: x[slice(None, None, -1)],
|
||||
attributes={"index": (slice(None, None, -1),)},
|
||||
kwargs={"index": (slice(None, None, -1),)},
|
||||
),
|
||||
["%0"],
|
||||
"%0[::-1]",
|
||||
|
||||
Reference in New Issue
Block a user