feat(frontend-python): implement len for tracers

This makes fhe.array extension work with tensors as well!
This commit is contained in:
Umut
2023-08-02 13:30:41 +02:00
parent 894ed9ec9f
commit cce0cd882f
5 changed files with 22 additions and 42 deletions

View File

@@ -44,9 +44,7 @@ def array(values: Any) -> Union[np.ndarray, Tracer]:
if not isinstance(value, Tracer):
values[i] = Tracer.sanitize(value)
if not values[i].output.is_scalar:
message = "Encrypted arrays can only be created from scalars"
raise ValueError(message)
assert values[i].output.is_scalar
dtype = combine_dtypes([value.output.dtype for value in values])
is_encrypted = True

View File

@@ -863,6 +863,13 @@ class Tracer:
return Tracer._trace_numpy_operation(np.transpose, self)
def __len__(self):
shape = self.shape
if len(shape) == 0:
message = "object of type 'Tracer' where 'shape == ()' has no len()"
raise TypeError(message)
return shape[0]
class Annotation(Tracer):
"""