mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(frontend-python): implement len for tracers
This makes fhe.array extension work with tensors as well!
This commit is contained in:
@@ -44,9 +44,7 @@ def array(values: Any) -> Union[np.ndarray, Tracer]:
|
||||
if not isinstance(value, Tracer):
|
||||
values[i] = Tracer.sanitize(value)
|
||||
|
||||
if not values[i].output.is_scalar:
|
||||
message = "Encrypted arrays can only be created from scalars"
|
||||
raise ValueError(message)
|
||||
assert values[i].output.is_scalar
|
||||
|
||||
dtype = combine_dtypes([value.output.dtype for value in values])
|
||||
is_encrypted = True
|
||||
|
||||
@@ -863,6 +863,13 @@ class Tracer:
|
||||
|
||||
return Tracer._trace_numpy_operation(np.transpose, self)
|
||||
|
||||
def __len__(self):
|
||||
shape = self.shape
|
||||
if len(shape) == 0:
|
||||
message = "object of type 'Tracer' where 'shape == ()' has no len()"
|
||||
raise TypeError(message)
|
||||
return shape[0]
|
||||
|
||||
|
||||
class Annotation(Tracer):
|
||||
"""
|
||||
|
||||
@@ -155,6 +155,18 @@ def test_compiler_bad_trace(helpers):
|
||||
|
||||
assert str(excinfo.value) == "Function 'g' returned '[{} ()]', which is not supported"
|
||||
|
||||
# len on scalar
|
||||
# -------------
|
||||
|
||||
def len_on_scalar(x):
|
||||
return len(x)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
compiler = Compiler(len_on_scalar, {"x": "encrypted"})
|
||||
compiler.trace(inputset=[1, 2, 3], configuration=configuration)
|
||||
|
||||
assert str(excinfo.value) == "object of type 'Tracer' where 'shape == ()' has no len()"
|
||||
|
||||
|
||||
def test_compiler_bad_compile(helpers):
|
||||
"""
|
||||
|
||||
@@ -526,11 +526,11 @@ def copy_modify(x):
|
||||
id="fusable_additional_2",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + x.shape[0] + x.ndim + x.size,
|
||||
lambda x: x + x.shape[0] + x.ndim + x.size + len(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 15], "shape": (3, 2)},
|
||||
},
|
||||
id="x + x.shape[0] + x.ndim + x.size",
|
||||
id="x + x.shape[0] + x.ndim + x.size + len(x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (50 * np.sin(x.transpose())).astype(np.int64),
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""
|
||||
Tests of 'array' extension.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from concrete import fhe
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,expected_error",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x, y: fhe.array([x, y]),
|
||||
{
|
||||
"x": {"range": [0, 10], "status": "encrypted", "shape": ()},
|
||||
"y": {"range": [0, 10], "status": "encrypted", "shape": (2, 3)},
|
||||
},
|
||||
"Encrypted arrays can only be created from scalars",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_bad_array(function, parameters, expected_error, helpers):
|
||||
"""
|
||||
Test array with bad parameters.
|
||||
"""
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
compiler = fhe.Compiler(function, parameter_encryption_statuses)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
inputset = helpers.generate_inputset(parameters)
|
||||
compiler.compile(inputset, configuration)
|
||||
|
||||
assert str(excinfo.value) == expected_error
|
||||
Reference in New Issue
Block a user