fix(compilation): work around slicing limitations

This commit is contained in:
Umut
2022-02-15 15:08:32 +03:00
parent 5bca3b19ae
commit 4bb5a54d6e
2 changed files with 65 additions and 51 deletions

View File

@@ -422,6 +422,8 @@ class IntermediateNodeConverter:
str: textual MLIR representation corresponding to self.node
"""
# pylint: disable=too-many-locals
assert_true(len(self.node.inputs) == 1)
assert_true(len(self.node.outputs) == 1)
@@ -448,10 +450,11 @@ class IntermediateNodeConverter:
sizes = []
strides = []
can_be_converted = True
destroyed_dimensions = []
for dimension, (indexing_element, dimension_size) in enumerate(zip(index, input_shape)):
if isinstance(indexing_element, int):
destroyed_dimensions.append(dimension)
size = 1
stride = 1
offset = (
@@ -466,10 +469,6 @@ class IntermediateNodeConverter:
input_shape,
index_str,
)
if size == 1:
can_be_converted = False
break
stride = indexing_element.step if isinstance(indexing_element.step, int) else 1
offset = (
(
@@ -484,20 +483,43 @@ class IntermediateNodeConverter:
else: # pragma: no cover
# this branch is impossible to reach with all the previous checks
# but let's keep it as an extra measure
can_be_converted = False
break
raise NotImplementedError(
f"Indexing of {input_value} with {index_str} cannot be converted to MLIR",
)
offsets.append(offset)
sizes.append(size)
strides.append(stride)
if not can_be_converted:
raise NotImplementedError(
f"Indexing of {input_value} with {index_str} cannot be converted to MLIR yet",
)
if len(destroyed_dimensions) == 0:
return tensor.ExtractSliceOp(
tensor_type,
pred,
[],
[],
[],
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in offsets]),
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in sizes]),
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
).result
return tensor.ExtractSliceOp(
tensor_type,
output_value = cast(TensorValue, self.node.outputs[0])
intermediate_shape = list(output_value.shape)
for dimension in destroyed_dimensions:
intermediate_shape.insert(dimension, 1)
intermediate_type = RankedTensorType.get(
intermediate_shape,
integer_to_mlir_type(
self.ctx,
cast(Integer, output_value.dtype),
output_value.is_encrypted,
),
)
intermediate = tensor.ExtractSliceOp(
intermediate_type,
pred,
[],
[],
@@ -507,6 +529,36 @@ class IntermediateNodeConverter:
ArrayAttr.get([IntegerAttr.get(index_type, value) for value in strides]),
).result
reassociaton = []
current_intermediate_dimension = 0
for _ in range(len(output_value.shape)):
indices = [current_intermediate_dimension]
while current_intermediate_dimension in destroyed_dimensions:
current_intermediate_dimension += 1
indices.append(current_intermediate_dimension)
reassociaton.append(indices)
current_intermediate_dimension += 1
while current_intermediate_dimension < len(intermediate_shape):
reassociaton[-1].append(current_intermediate_dimension)
current_intermediate_dimension += 1
return linalg.TensorCollapseShapeOp(
tensor_type,
intermediate,
ArrayAttr.get(
[
ArrayAttr.get(
[IntegerAttr.get(index_type, index) for index in indices],
)
for indices in reassociaton
],
),
).result
# pylint: enable=too-many-locals
def convert_matmul(self) -> OpResult:
"""Convert a MatMul node to its corresponding MLIR representation.

View File

@@ -688,7 +688,6 @@ def test_invalid_constant_indexing_with_numpy_values(
[np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
([[11, 12], [21, 22], [31, 32]],),
[11, 12],
marks=pytest.mark.xfail(strict=True),
),
pytest.param(
lambda x: x[:, 0],
@@ -698,7 +697,6 @@ def test_invalid_constant_indexing_with_numpy_values(
[np.random.randint(0, 2 ** 6, size=(3, 2)) for _ in range(10)],
([[11, 12], [21, 22], [31, 32]],),
[11, 21, 31],
marks=pytest.mark.xfail(strict=True),
),
],
)
@@ -727,39 +725,3 @@ def test_constant_indexing_run_correctness(
expected = np.array(expected_output, dtype=np.uint8)
check_array_equality(output, expected)
@pytest.mark.parametrize(
"function,parameters,inputset,match",
[
pytest.param(
lambda x: x[0:1],
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3,)),
},
[np.random.randint(0, 2 ** 3, size=(3,)) for _ in range(10)],
(
"Indexing of EncryptedTensor<uint3, shape=(3,)> with [0:1] "
"cannot be converted to MLIR yet"
),
),
],
)
def test_constant_indexing_failed_compilation(
function,
parameters,
inputset,
match,
default_compilation_configuration,
):
"""Test compilation failures of compiled function with constant indexing"""
with pytest.raises(RuntimeError) as excinfo:
compile_numpy_function(
function,
parameters,
inputset,
default_compilation_configuration,
)
assert str(excinfo.value) == match, str(excinfo.value)