feat(mlir): implement mlir conversion of reshape and flatten

This commit is contained in:
Umut
2021-12-02 13:34:21 +03:00
parent 7fcab62cd5
commit c8ec2a2340
8 changed files with 409 additions and 35 deletions

View File

@@ -1675,23 +1675,6 @@ function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
%1 = ravel(%0) # EncryptedTensor<uint3, shape=(6,)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: numpy.reshape(x, (2, 6)),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
[numpy.random.randint(0, 2 ** 3, size=(3, 4)) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint3, shape=(3, 4)>
%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor<uint3, shape=(2, 6)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
return %1
""".strip() # noqa: E501

View File

@@ -0,0 +1,265 @@
"""Test module for memory operations."""
import numpy
import pytest
from concrete.common.data_types import UnsignedInteger
from concrete.common.values import EncryptedTensor
from concrete.numpy import compile_numpy_function
@pytest.mark.parametrize(
"function,parameters,inputset,test_input,expected_output",
[
pytest.param(
lambda x: x.flatten(),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
},
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
[[0, 1], [1, 2], [2, 3]],
[0, 1, 1, 2, 2, 3],
),
pytest.param(
lambda x: x.flatten(),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
(numpy.arange(720) % 10),
),
pytest.param(
lambda x: x.reshape((1, 3)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
},
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
[5, 9, 1],
[[5, 9, 1]],
),
pytest.param(
lambda x: x.reshape((3, 1)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
},
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
[5, 9, 1],
[[5], [9], [1]],
),
pytest.param(
lambda x: x.reshape((3, 2)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
},
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
[[0, 1], [1, 2], [2, 3]],
[[0, 1], [1, 2], [2, 3]],
),
pytest.param(
lambda x: x.reshape((3, 2)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
[[0, 1, 1], [2, 2, 3]],
[[0, 1], [1, 2], [2, 3]],
),
pytest.param(
lambda x: x.reshape(-1),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
},
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
[[0, 1], [1, 2], [2, 3]],
[0, 1, 1, 2, 2, 3],
),
pytest.param(
lambda x: x.reshape((2, 2, 3)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
},
[numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
(numpy.arange(12) % 10).reshape((4, 3)),
(numpy.arange(12) % 10).reshape((2, 2, 3)),
),
pytest.param(
lambda x: x.reshape((4, 3)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
(numpy.arange(12) % 10).reshape((2, 2, 3)),
(numpy.arange(12) % 10).reshape((4, 3)),
),
pytest.param(
lambda x: x.reshape((3, 2, 2)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
},
[numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
(numpy.arange(12) % 10).reshape((3, 4)),
(numpy.arange(12) % 10).reshape((3, 2, 2)),
),
pytest.param(
lambda x: x.reshape((3, 4)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
},
[numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
(numpy.arange(12) % 10).reshape((3, 2, 2)),
(numpy.arange(12) % 10).reshape((3, 4)),
),
pytest.param(
lambda x: x.reshape((5, 3, 2)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
},
[numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
(numpy.arange(30) % 10).reshape((6, 5)),
(numpy.arange(30) % 10).reshape((5, 3, 2)),
),
pytest.param(
lambda x: x.reshape((5, 6)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
(numpy.arange(30) % 10).reshape((2, 3, 5)),
(numpy.arange(30) % 10).reshape((5, 6)),
),
pytest.param(
lambda x: x.reshape((6, 4, 30)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
(numpy.arange(720) % 10).reshape((6, 4, 30)),
),
pytest.param(
lambda x: x.reshape((2, 60, 6)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
(numpy.arange(720) % 10).reshape((2, 60, 6)),
),
pytest.param(
lambda x: x.reshape((6, 6, -1)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
(numpy.arange(144) % 10).reshape((6, 6, -1)),
),
pytest.param(
lambda x: x.reshape((6, -1, 12)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
(numpy.arange(144) % 10).reshape((6, -1, 12)),
),
pytest.param(
lambda x: x.reshape((-1, 18, 4)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
(numpy.arange(144) % 10).reshape((-1, 18, 4)),
),
],
)
def test_memory_operation_run_correctness(
function,
parameters,
inputset,
test_input,
expected_output,
default_compilation_configuration,
check_array_equality,
):
"""
Test correctness of results when running a compiled function with memory operators.
e.g.,
- flatten
- reshape
"""
circuit = compile_numpy_function(
function,
parameters,
inputset,
default_compilation_configuration,
)
actual = circuit.run(numpy.array(test_input, dtype=numpy.uint8))
expected = numpy.array(expected_output, dtype=numpy.uint8)
check_array_equality(actual, expected)
@pytest.mark.parametrize(
"function,parameters,inputset,error,match",
[
pytest.param(
lambda x: x.reshape((-1, -1, 2)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
ValueError,
"shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
),
pytest.param(
lambda x: x.reshape((3, -1, 3)),
{
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
},
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
ValueError,
"shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
),
],
)
def test_memory_operation_failed_compilation(
function,
parameters,
inputset,
error,
match,
default_compilation_configuration,
):
"""
Test compilation failures of compiled function with memory operations.
e.g.,
- reshape
"""
with pytest.raises(error) as excinfo:
compile_numpy_function(
function,
parameters,
inputset,
default_compilation_configuration,
)
assert (
str(excinfo.value) == match
), f"""
Actual Output
=============
{excinfo.value}
Expected Output
===============
{match}
"""

View File

@@ -644,7 +644,7 @@ def test_tracing_numpy_calls(
None,
)
],
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
marks=pytest.mark.xfail(strict=True, raises=ValueError),
),
pytest.param(
lambda x: x.flatten(),
@@ -985,7 +985,7 @@ def test_tracing_ndarray_calls(
)
def test_errors_with_generic_function(lambda_f, params):
"Test some errors with generic function"
with pytest.raises(AssertionError) as excinfo:
with pytest.raises(ValueError) as excinfo:
tracing.trace_numpy_function(lambda_f, params)
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)

View File

@@ -294,7 +294,7 @@ def test_tracing_numpy_calls(
None,
)
],
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
marks=pytest.mark.xfail(strict=True, raises=ValueError),
),
],
)

View File

@@ -178,7 +178,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
)
def test_errors_with_generic_function(lambda_f, params):
"Test some errors with generic function"
with pytest.raises(AssertionError) as excinfo:
with pytest.raises(ValueError) as excinfo:
tracing.trace_numpy_function(lambda_f, params)
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)