fix(compilation): work around slicing limitations
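An illustrative note (not part of the commit): an integer index destroys a dimension, which a plain tensor.extract_slice cannot express on its own. The workaround in the hunks below extracts a slice that keeps a unit-sized dimension in place of each destroyed one, then collapses those unit dimensions with linalg.TensorCollapseShapeOp. A minimal NumPy sketch of the same idea:

import numpy as np

x = np.arange(3 * 4 * 5).reshape(3, 4, 5)

# Direct indexing: the integer index destroys dimension 0.
direct = x[1, :, ::2]

# Workaround: extract a unit-sized slice instead, then collapse the unit dimension.
extracted = x[1:2, :, ::2]           # shape (1, 4, 3), expressible as an extract_slice
collapsed = extracted.reshape(4, 3)  # collapse the destroyed dimension

assert np.array_equal(direct, collapsed)

The reshape plays the role of the collapse op; the MLIR lowering below does the same thing with an explicit reassociation of dimensions.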
@@ -422,6 +422,8 @@ class IntermediateNodeConverter:
            str: textual MLIR representation corresponding to self.node
        """

        # pylint: disable=too-many-locals

        assert_true(len(self.node.inputs) == 1)
        assert_true(len(self.node.outputs) == 1)

@@ -448,10 +450,11 @@ class IntermediateNodeConverter:
        sizes = []
        strides = []

        can_be_converted = True
        destroyed_dimensions = []
        for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):

            if isinstance(indexing_element, int):
                destroyed_dimensions.append(dimension)
                size = 1
                stride = 1
                offset = (
@@ -466,10 +469,6 @@ class IntermediateNodeConverter:
                    input_shape,
                    index_str,
                )
                if size == 1:
                    can_be_converted = False
                    break

                stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
                offset = (
                    (
@@ -484,20 +483,43 @@ class IntermediateNodeConverter:
            else:  # pragma: no cover
                # this branch is impossible to reach with all the previous checks
                # but let's keep it as an extra measure
                can_be_converted = False
                break
                raise NotImplementedError(
                    f"Indexing of {input_value} with {index_str} cannot be converted to MLIR",
                )

            offsets.append(offset)
            sizes.append(size)
            strides.append(stride)

        if not can_be_converted:
            raise NotImplementedError(
                f"Indexing of {input_value} with {index_str} cannot be converted to MLIR yet",
            )
        if len(destroyed_dimensions) == 0:
            return tensor.ExtractSliceOp(
                tensor_type,
                pred,
                [],
                [],
                [],
                ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
                ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
                ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
            ).result

        return tensor.ExtractSliceOp(
            tensor_type,
        output_value = cast(TensorValue, self.node.outputs[0])

        intermediate_shape = list(output_value.shape)
        for dimension in destroyed_dimensions:
            intermediate_shape.insert(dimension, 1)

        intermediate_type = RankedTensorType.get(
            intermediate_shape,
            integer_to_mlir_type(
                self.ctx,
                cast(Integer, output_value.dtype),
                output_value.is_encrypted,
            ),
        )

        intermediate = tensor.ExtractSliceOp(
            intermediate_type,
            pred,
            [],
            [],
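Aside (not from the diff): each indexing element contributes an (offset, size, stride) triple to the extract-slice built above. The exact offset/size computation for slices sits outside the shown hunks, so the helper below is a hypothetical sketch covering only non-negative steps, to show what the triple means:

def slice_to_offset_size_stride(indexing_element: slice, dimension_size: int):
    # Hypothetical helper, not from the commit: normalize a forward slice into the
    # (offset, size, stride) triple used per dimension by the extract-slice op.
    start, stop, step = indexing_element.indices(dimension_size)
    size = max(0, -((start - stop) // step))  # ceil((stop - start) / step) for step > 0
    return start, size, step

assert slice_to_offset_size_stride(slice(None, None, 2), 5) == (0, 3, 2)
assert slice_to_offset_size_stride(slice(1, 4), 5) == (1, 3, 1)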
@@ -507,6 +529,36 @@ class IntermediateNodeConverter:
            ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
        ).result

        reassociaton = []

        current_intermediate_dimension = 0
        for _ in range(len(output_value.shape)):
            indices = [current_intermediate_dimension]
            while current_intermediate_dimension in destroyed_dimensions:
                current_intermediate_dimension += 1
                indices.append(current_intermediate_dimension)

            reassociaton.append(indices)
            current_intermediate_dimension += 1
        while current_intermediate_dimension < len(intermediate_shape):
            reassociaton[-1].append(current_intermediate_dimension)
            current_intermediate_dimension += 1

        return linalg.TensorCollapseShapeOp(
            tensor_type,
            intermediate,
            ArrayAttr.get(
                [
                    ArrayAttr.get(
                        [IntegerAttr.get(index_type, index) for index in indices],
                    )
                    for indices in reassociaton
                ],
            ),
        ).result

        # pylint: enable=too-many-locals

    def convert_matmul(self) -> OpResult:
        """Convert a MatMul node to its corresponding MLIR representation.

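To make the reassociation loop above easier to follow, here is a standalone restatement (my sketch, not code from the commit) of how the `reassociaton` groups are built: every destroyed, unit-sized intermediate dimension is folded into the group of an adjacent kept dimension, so the collapse removes exactly those unit extents.

def compute_reassociation(destroyed_dimensions, output_rank, intermediate_rank):
    # Standalone restatement of the reassociation loop from the hunk above:
    # build one index group per output dimension, folding each destroyed
    # (unit-sized) intermediate dimension into an adjacent group.
    reassociation = []
    current = 0
    for _ in range(output_rank):
        indices = [current]
        while current in destroyed_dimensions:
            current += 1
            indices.append(current)
        reassociation.append(indices)
        current += 1
    while current < intermediate_rank:  # trailing destroyed dimensions
        reassociation[-1].append(current)
        current += 1
    return reassociation

# e.g. x[0, :, ::2] on a rank-3 tensor: dimension 0 is destroyed, the
# intermediate tensor has shape (1, n, m) and collapses to shape (n, m).
assert compute_reassociation({0}, 2, 3) == [[0, 1], [2]]
# e.g. x[:, 0] on a rank-2 tensor: a trailing destroyed dimension joins the last group.
assert compute_reassociation({1}, 1, 2) == [[0, 1]]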
@@ -688,7 +688,6 @@ def test_invalid_constant_indexing_with_numpy_values(
            [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [11, 12],
            marks=pytest.mark.xfail(strict=True),
        ),
        pytest.param(
            lambda x: x[:, 0],
@@ -698,7 +697,6 @@ def test_invalid_constant_indexing_with_numpy_values(
            [np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
            ([[11, 12], [21, 22], [31, 32]],),
            [11, 21, 31],
            marks=pytest.mark.xfail(strict=True),
        ),
    ],
)
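For reference (not part of the diff), the expected outputs in the two previously xfail-marked cases above match plain NumPy indexing on the ([[11, 12], [21, 22], [31, 32]],) input:

import numpy as np

data = np.array([[11, 12], [21, 22], [31, 32]])
assert np.array_equal(data[0], [11, 12])         # row indexing destroys dimension 0
assert np.array_equal(data[:, 0], [11, 21, 31])  # column indexing destroys dimension 1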
@@ -727,39 +725,3 @@ def test_constant_indexing_run_correctness(
    expected = np.array(expected_output, dtype=np.uint8)

    check_array_equality(output, expected)


@pytest.mark.parametrize(
    "function,parameters,inputset,match",
    [
        pytest.param(
            lambda x: x[0:1],
            {
                "x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
            },
            [np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)],
            (
                "Indexing of EncryptedTensor<uint3, shape=(3,)> with [0:1] "
                "cannot be converted to MLIR yet"
            ),
        ),
    ],
)
def test_constant_indexing_failed_compilation(
    function,
    parameters,
    inputset,
    match,
    default_compilation_configuration,
):
    """Test compilation failures of compiled function with constant indexing"""

    with pytest.raises(RuntimeError) as excinfo:
        compile_numpy_function(
            function,
            parameters,
            inputset,
            default_compilation_configuration,
        )

    assert str(excinfo.value) == match, str(excinfo.value)