feat(mlir): implement mlir conversion of reshape and flatten

Umut
2021-12-02 13:34:21 +03:00
parent 7fcab62cd5
commit c8ec2a2340
8 changed files with 409 additions and 35 deletions


@@ -7,7 +7,7 @@
from typing import Any, Dict, List, Tuple, cast

import numpy
-from mlir.dialects import arith, tensor
+from mlir.dialects import arith, linalg, tensor
from mlir.ir import (
    ArrayAttr,
    Attribute,
@@ -117,7 +117,11 @@ class IntermediateNodeConverter:
            result = self.convert_dot()
        elif isinstance(self.node, GenericFunction):
-            result = self.convert_generic_function(additional_conversion_info)
+            if self.node.op_name in ["flatten", "reshape"]:
+                # notice flatten() == reshape(-1) and convert_reshape can handle that
+                result = self.convert_reshape()
+            else:
+                result = self.convert_generic_function(additional_conversion_info)
        elif isinstance(self.node, IndexConstant):
            result = self.convert_index_constant()
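
The dispatch above relies on the identity flatten() == reshape(-1). As a quick standalone sanity check in plain numpy (illustrative snippet, not part of the diff):

import numpy

x = numpy.arange(6).reshape(3, 2)

# flatten always collapses to a single dimension, exactly like reshape(-1),
# so a single conversion path can serve both operations
assert numpy.array_equal(x.flatten(), x.reshape(-1))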
@@ -481,6 +485,129 @@ class IntermediateNodeConverter:
        return result

    def convert_reshape(self) -> OpResult:
        """Convert a "reshape" node to its corresponding MLIR representation.

        Returns:
            OpResult: the MLIR operation result corresponding to self.node
        """
        assert_true(len(self.node.inputs) == 1)
        assert_true(len(self.node.outputs) == 1)

        assert_true(isinstance(self.node.inputs[0], TensorValue))
        input_shape = cast(TensorValue, self.node.inputs[0]).shape

        assert_true(isinstance(self.node.outputs[0], TensorValue))
        output_shape = cast(TensorValue, self.node.outputs[0]).shape

        pred = self.preds[0]
        if input_shape == output_shape:
            return pred

        # we can only collapse or expand, both of which change the number of dimensions
        # this is a limitation of the current compiler, and it will be improved in the future (#1060)
        can_be_converted_directly = len(input_shape) != len(output_shape)

        reassociation: List[List[int]] = []
        if can_be_converted_directly:
            if len(output_shape) == 1:
                # the output is 1-dimensional, so collapse every dimension into a single group
                reassociation.append(list(range(len(input_shape))))
            else:
                # the input is m-dimensional
                # the output is n-dimensional
                # and m is different from n

                # we don't want to duplicate code, so we forget about input and output
                # and focus on the smaller and the bigger shape instead
                smaller_shape, bigger_shape = (
                    (output_shape, input_shape)
                    if len(output_shape) < len(input_shape)
                    else (input_shape, output_shape)
                )
                s_index, b_index = 0, 0

                # now we will figure out how to group the bigger shape to get the smaller shape
                # think of the algorithm below as:
                #   keep merging the dimensions of the bigger shape
                #   until we have a match on the smaller shape,
                #   then try to match the next dimension of the smaller shape;
                #   if all dimensions of the smaller shape are matched,
                #   we can convert directly

                group = []
                size = 1
                while s_index < len(smaller_shape) and b_index < len(bigger_shape):
                    # dimension `b_index` of `bigger_shape` belongs to the current group
                    group.append(b_index)

                    # and the current group now has `size * bigger_shape[b_index]` elements
                    size *= bigger_shape[b_index]

                    # if the current group size matches dimension `s_index` of `smaller_shape`
                    if size == smaller_shape[s_index]:
                        # finalize this group and reset everything
                        size = 1
                        reassociation.append(group)
                        group = []

                        # then try to match the next dimension of `smaller_shape`
                        s_index += 1

                    # process the next dimension of `bigger_shape`
                    b_index += 1

                # handle the case where the bigger shape has trailing 1s
                # e.g., (5,) -> (5, 1)
                while b_index < len(bigger_shape) and bigger_shape[b_index] == 1:
                    reassociation[-1].append(b_index)
                    b_index += 1

                # if not all dimensions of both shapes were matched exactly
                if s_index != len(smaller_shape) or b_index != len(bigger_shape):
                    # we cannot convert directly
                    can_be_converted_directly = False

        index_type = IndexType.parse("index")
        resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])

        if can_be_converted_directly:
            reassociation_attr = ArrayAttr.get(
                [
                    ArrayAttr.get([IntegerAttr.get(index_type, dimension) for dimension in group])
                    for group in reassociation
                ]
            )
            if len(output_shape) < len(input_shape):
                return linalg.TensorCollapseShapeOp(resulting_type, pred, reassociation_attr).result
            return linalg.TensorExpandShapeOp(resulting_type, pred, reassociation_attr).result

        # otherwise, collapse the input into a single dimension
        # and expand the result into the requested output shape
        flattened_type = value_to_mlir_type(
            self.ctx,
            TensorValue(
                self.node.inputs[0].dtype,
                self.node.inputs[0].is_encrypted,
                (numpy.prod(input_shape),),
            ),
        )
        flattened_result = linalg.TensorCollapseShapeOp(
            flattened_type,
            pred,
            ArrayAttr.get(
                [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(input_shape))])]
            ),
        ).result

        return linalg.TensorExpandShapeOp(
            resulting_type,
            flattened_result,
            ArrayAttr.get(
                [ArrayAttr.get([IntegerAttr.get(index_type, i) for i in range(len(output_shape))])]
            ),
        ).result

    def convert_sub(self) -> OpResult:
        """Convert a Sub node to its corresponding MLIR representation.

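
The grouping loop in convert_reshape above is the core of the change: it tries to express a reshape as a single collapse or expand by merging dimensions of the bigger shape into groups whose element counts match the dimensions of the smaller shape. Here is a minimal standalone sketch of that computation (hypothetical helper name, written for illustration; it mirrors, rather than reuses, the code in the diff):

from typing import List, Optional, Tuple


def reassociation_groups(
    input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]
) -> Optional[List[List[int]]]:
    """Group dimensions of the bigger shape so that each group's element
    count matches one dimension of the smaller shape; None if impossible."""
    smaller, bigger = sorted((input_shape, output_shape), key=len)

    groups: List[List[int]] = []
    group: List[int] = []
    size = 1

    s_index, b_index = 0, 0
    while s_index < len(smaller) and b_index < len(bigger):
        group.append(b_index)
        size *= bigger[b_index]
        if size == smaller[s_index]:
            # the current group now matches dimension `s_index`, finalize it
            groups.append(group)
            group, size = [], 1
            s_index += 1
        b_index += 1

    # absorb trailing 1s of the bigger shape, e.g., (5,) -> (5, 1)
    while b_index < len(bigger) and bigger[b_index] == 1:
        groups[-1].append(b_index)
        b_index += 1

    # if either shape was not consumed exactly, no single collapse/expand works
    if s_index != len(smaller) or b_index != len(bigger):
        return None
    return groups


# (2, 2, 3) -> (4, 3): dimensions 0 and 1 collapse into one group
assert reassociation_groups((2, 2, 3), (4, 3)) == [[0, 1], [2]]

# (3, 2) -> (2, 3): no grouping works
assert reassociation_groups((3, 2), (2, 3)) is None

When no grouping exists, the conversion above emits two operations instead of one: a TensorCollapseShapeOp down to a single dimension followed by a TensorExpandShapeOp into the requested shape.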

@@ -89,7 +89,8 @@ def check_node_compatibility_with_mlir(
                == 1
            )
        else:
-            return f"{node.op_name} is not supported for the time being"
+            if node.op_name not in ["flatten", "reshape"]:
+                return f"{node.op_name} is not supported for the time being"

    elif isinstance(node, intermediate.Dot):  # constraints for dot product
        assert_true(len(inputs) == 2)


@@ -362,13 +362,13 @@ class NPTracer(BaseTracer):
        )
        return output_tracer

-    def reshape(self, arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
+    def reshape(self, newshape: Tuple[Any, ...], **kwargs) -> "NPTracer":
        """Trace x.reshape.

        Returns:
            NPTracer: The output NPTracer containing the traced function
        """
-        return self.numpy_reshape(self, arg1, **kwargs)
+        return self.numpy_reshape(self, newshape, **kwargs)

    def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer":
        """Trace numpy.reshape.
@@ -390,15 +390,13 @@ class NPTracer(BaseTracer):
        assert_true(isinstance(first_arg_output, TensorValue))
        first_arg_output = cast(TensorValue, first_arg_output)

-        # Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work,
-        # while classical form is numpy.reshape(x, (170,))
-        newshape = deepcopy(arg1) if not isinstance(arg1, int) else (arg1,)
-
-        # Check shape compatibility
-        assert_true(
-            numpy.product(newshape) == first_arg_output.size,
-            f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})",
-        )
+        try:
+            # calculate the newshape using numpy to handle edge cases such as `-1`s within it
+            newshape = numpy.zeros(first_arg_output.shape).reshape(arg1).shape
+        except Exception as error:
+            raise ValueError(
+                f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {arg1})"
+            ) from error

        reshape_is_fusable = newshape == first_arg_output.shape
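
The try/except above delegates shape validation to numpy itself: reshaping a dummy array of zeros resolves any -1 entry and rejects incompatible sizes in one step. A standalone illustration of the trick (not part of the diff):

import numpy

old_shape = (2, 3, 4)

# numpy resolves the -1 into 24 / 6 == 4
assert numpy.zeros(old_shape).reshape((6, -1)).shape == (6, 4)

# incompatible sizes raise ValueError, which the tracer re-raises
# with a friendlier "shapes are not compatible" message
try:
    numpy.zeros(old_shape).reshape((5, -1))
except ValueError:
    pass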


@@ -1675,23 +1675,6 @@ function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = ravel(%0) # EncryptedTensor<uint3, shape=(6,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
return %1
                """.strip()  # noqa: E501
            ),
        ),
-        pytest.param(
-            lambda x: numpy.reshape(x, (2, 6)),
-            {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
-            [numpy.random.randint(0, 2 ** 3, size=(3, 4)) for i in range(10)],
-            (
-                """
-function you are trying to compile isn't supported for MLIR lowering
-%0 = x # EncryptedTensor<uint3, shape=(3, 4)>
-%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor<uint3, shape=(2, 6)>
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
-return %1
-                """.strip()  # noqa: E501


@@ -0,0 +1,265 @@
"""Test module for memory operations."""

import numpy
import pytest

from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedTensor
from concrete.numpy import compile_numpy_function


@pytest.mark.parametrize(
    "function,parameters,inputset,test_input,expected_output",
    [
        pytest.param(
            lambda x: x.flatten(),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
            [[0, 1], [1, 2], [2, 3]],
            [0, 1, 1, 2, 2, 3],
        ),
        pytest.param(
            lambda x: x.flatten(),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
            (numpy.arange(720) % 10),
        ),
        pytest.param(
            lambda x: x.reshape((1, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
            [5, 9, 1],
            [[5, 9, 1]],
        ),
        pytest.param(
            lambda x: x.reshape((3, 1)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
            [5, 9, 1],
            [[5], [9], [1]],
        ),
        pytest.param(
            lambda x: x.reshape((3, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
            [[0, 1], [1, 2], [2, 3]],
            [[0, 1], [1, 2], [2, 3]],
        ),
        pytest.param(
            lambda x: x.reshape((3, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
            [[0, 1, 1], [2, 2, 3]],
            [[0, 1], [1, 2], [2, 3]],
        ),
        pytest.param(
            lambda x: x.reshape(-1),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
            [[0, 1], [1, 2], [2, 3]],
            [0, 1, 1, 2, 2, 3],
        ),
        pytest.param(
            lambda x: x.reshape((2, 2, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((4, 3)),
            (numpy.arange(12) % 10).reshape((2, 2, 3)),
        ),
        pytest.param(
            lambda x: x.reshape((4, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((2, 2, 3)),
            (numpy.arange(12) % 10).reshape((4, 3)),
        ),
        pytest.param(
            lambda x: x.reshape((3, 2, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((3, 4)),
            (numpy.arange(12) % 10).reshape((3, 2, 2)),
        ),
        pytest.param(
            lambda x: x.reshape((3, 4)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
            (numpy.arange(12) % 10).reshape((3, 2, 2)),
            (numpy.arange(12) % 10).reshape((3, 4)),
        ),
        pytest.param(
            lambda x: x.reshape((5, 3, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
            (numpy.arange(30) % 10).reshape((6, 5)),
            (numpy.arange(30) % 10).reshape((5, 3, 2)),
        ),
        pytest.param(
            lambda x: x.reshape((5, 6)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
            (numpy.arange(30) % 10).reshape((2, 3, 5)),
            (numpy.arange(30) % 10).reshape((5, 6)),
        ),
        pytest.param(
            lambda x: x.reshape((6, 4, 30)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
            (numpy.arange(720) % 10).reshape((6, 4, 30)),
        ),
        pytest.param(
            lambda x: x.reshape((2, 60, 6)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
            (numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
            (numpy.arange(720) % 10).reshape((2, 60, 6)),
        ),
        pytest.param(
            lambda x: x.reshape((6, 6, -1)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
            (numpy.arange(144) % 10).reshape((6, 6, -1)),
        ),
        pytest.param(
            lambda x: x.reshape((6, -1, 12)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
            (numpy.arange(144) % 10).reshape((6, -1, 12)),
        ),
        pytest.param(
            lambda x: x.reshape((-1, 18, 4)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
            (numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
            (numpy.arange(144) % 10).reshape((-1, 18, 4)),
        ),
    ],
)
def test_memory_operation_run_correctness(
    function,
    parameters,
    inputset,
    test_input,
    expected_output,
    default_compilation_configuration,
    check_array_equality,
):
    """
    Test correctness of results when running a compiled function with memory operations.

    e.g.,
    - flatten
    - reshape
    """
    circuit = compile_numpy_function(
        function,
        parameters,
        inputset,
        default_compilation_configuration,
    )

    actual = circuit.run(numpy.array(test_input, dtype=numpy.uint8))
    expected = numpy.array(expected_output, dtype=numpy.uint8)

    check_array_equality(actual, expected)


@pytest.mark.parametrize(
    "function,parameters,inputset,error,match",
    [
        pytest.param(
            lambda x: x.reshape((-1, -1, 2)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
            ValueError,
            "shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
        ),
        pytest.param(
            lambda x: x.reshape((3, -1, 3)),
            {
                "x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
            },
            [numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
            ValueError,
            "shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
        ),
    ],
)
def test_memory_operation_failed_compilation(
    function,
    parameters,
    inputset,
    error,
    match,
    default_compilation_configuration,
):
    """
    Test compilation failures of functions with memory operations.

    e.g.,
    - reshape
    """
    with pytest.raises(error) as excinfo:
        compile_numpy_function(
            function,
            parameters,
            inputset,
            default_compilation_configuration,
        )

    assert (
        str(excinfo.value) == match
    ), f"""

Actual Output
=============
{excinfo.value}

Expected Output
===============
{match}

    """


@@ -644,7 +644,7 @@ def test_tracing_numpy_calls(
                    None,
                )
            ],
-            marks=pytest.mark.xfail(strict=True, raises=AssertionError),
+            marks=pytest.mark.xfail(strict=True, raises=ValueError),
        ),
        pytest.param(
            lambda x: x.flatten(),
@@ -985,7 +985,7 @@ def test_tracing_ndarray_calls(
)
def test_errors_with_generic_function(lambda_f, params):
    "Test some errors with generic function"
-    with pytest.raises(AssertionError) as excinfo:
+    with pytest.raises(ValueError) as excinfo:
        tracing.trace_numpy_function(lambda_f, params)

    assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)


@@ -294,7 +294,7 @@ def test_tracing_numpy_calls(
                    None,
                )
            ],
-            marks=pytest.mark.xfail(strict=True, raises=AssertionError),
+            marks=pytest.mark.xfail(strict=True, raises=ValueError),
        ),
    ],
)


@@ -178,7 +178,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
)
def test_errors_with_generic_function(lambda_f, params):
    "Test some errors with generic function"
-    with pytest.raises(AssertionError) as excinfo:
+    with pytest.raises(ValueError) as excinfo:
        tracing.trace_numpy_function(lambda_f, params)

    assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)