feat(frontend-python): improve indexing error messages

This commit is contained in:
Umut
2023-09-18 15:05:30 +02:00
parent 73f01468e7
commit 51f5ed9484
2 changed files with 43 additions and 7 deletions

View File

@@ -179,6 +179,9 @@ class Tracer:
def __hash__(self) -> int:
return id(self)
def __str__(self) -> str:
return f"Tracer<output={self.output}>"
def __bool__(self) -> bool:
# pylint: disable=invalid-bool-returned
@@ -742,7 +745,12 @@ class Tracer:
int, np.integer, slice, "Tracer", Tuple[Union[int, np.integer, slice, "Tracer"], ...]
],
) -> "Tracer":
if isinstance(index, Tracer) and index.output.is_encrypted and self.output.is_clear:
if (
isinstance(index, Tracer)
and index.output.is_encrypted
and self.output.is_clear
and not self.output.is_scalar
):
computation = Node.generic(
"dynamic_tlu",
[deepcopy(index.output), deepcopy(self.output)],
@@ -754,6 +762,7 @@ class Tracer:
if not isinstance(index, tuple):
index = (index,)
reject = False
for indexing_element in index:
valid = isinstance(indexing_element, (int, np.integer, slice))
@@ -775,10 +784,20 @@ class Tracer:
valid = False
if not valid:
message = (
f"Indexing with '{format_indexing_element(indexing_element)}' is not supported"
)
raise ValueError(message)
reject = True
break
if reject:
indexing_elements = [
format_indexing_element(indexing_element) for indexing_element in index
]
formatted_index = (
indexing_elements[0]
if len(indexing_elements) == 1
else ", ".join(str(element) for element in indexing_elements)
)
message = f"{self} cannot be indexed with {formatted_index}"
raise ValueError(message)
output_value = deepcopy(self.output)
output_value.shape = np.zeros(output_value.shape)[index].shape # type: ignore