mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(mlir): implement mlir conversion of reshape and flatten
This commit is contained in:
@@ -1675,23 +1675,6 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 2)>
|
||||
%1 = ravel(%0) # EncryptedTensor<uint3, shape=(6,)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ravel is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.reshape(x, (2, 6)),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 4))},
|
||||
[numpy.random.randint(0, 2 ** 3, size=(3, 4)) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint3, shape=(3, 4)>
|
||||
%1 = reshape(%0, newshape=(2, 6)) # EncryptedTensor<uint3, shape=(2, 6)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ reshape is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
|
||||
265
tests/numpy/test_compile_memory_operations.py
Normal file
265
tests/numpy/test_compile_memory_operations.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Test module for memory operations."""
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types import UnsignedInteger
|
||||
from concrete.common.values import EncryptedTensor
|
||||
from concrete.numpy import compile_numpy_function
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[0, 1, 1, 2, 2, 3],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((1, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
|
||||
[5, 9, 1],
|
||||
[[5, 9, 1]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 1)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3,)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3,)) for _ in range(10)],
|
||||
[5, 9, 1],
|
||||
[[5], [9], [1]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3)) for _ in range(10)],
|
||||
[[0, 1, 1], [2, 2, 3]],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape(-1),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2)) for _ in range(10)],
|
||||
[[0, 1], [1, 2], [2, 3]],
|
||||
[0, 1, 1, 2, 2, 3],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((2, 2, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(4, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(4, 3)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((4, 3)),
|
||||
(numpy.arange(12) % 10).reshape((2, 2, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((4, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 2, 3)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 2, 3)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((2, 2, 3)),
|
||||
(numpy.arange(12) % 10).reshape((4, 3)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 2, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 4)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((3, 4)),
|
||||
(numpy.arange(12) % 10).reshape((3, 2, 2)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, 4)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(3, 2, 2)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(3, 2, 2)) for _ in range(10)],
|
||||
(numpy.arange(12) % 10).reshape((3, 2, 2)),
|
||||
(numpy.arange(12) % 10).reshape((3, 4)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 3, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(6, 5)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(6, 5)) for _ in range(10)],
|
||||
(numpy.arange(30) % 10).reshape((6, 5)),
|
||||
(numpy.arange(30) % 10).reshape((5, 3, 2)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((5, 6)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 5)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 5)) for _ in range(10)],
|
||||
(numpy.arange(30) % 10).reshape((2, 3, 5)),
|
||||
(numpy.arange(30) % 10).reshape((5, 6)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, 4, 30)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10).reshape((6, 4, 30)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((2, 60, 6)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4, 5, 6)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4, 5, 6)) for _ in range(10)],
|
||||
(numpy.arange(720) % 10).reshape((2, 3, 4, 5, 6)),
|
||||
(numpy.arange(720) % 10).reshape((2, 60, 6)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, 6, -1)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((6, 6, -1)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((6, -1, 12)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((6, -1, 12)),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((-1, 18, 4)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 2, 3, 4)) for _ in range(10)],
|
||||
(numpy.arange(144) % 10).reshape((2, 3, 2, 3, 4)),
|
||||
(numpy.arange(144) % 10).reshape((-1, 18, 4)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_memory_operation_run_correctness(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
test_input,
|
||||
expected_output,
|
||||
default_compilation_configuration,
|
||||
check_array_equality,
|
||||
):
|
||||
"""
|
||||
Test correctness of results when running a compiled function with memory operators.
|
||||
|
||||
e.g.,
|
||||
- flatten
|
||||
- reshape
|
||||
"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
actual = circuit.run(numpy.array(test_input, dtype=numpy.uint8))
|
||||
expected = numpy.array(expected_output, dtype=numpy.uint8)
|
||||
|
||||
check_array_equality(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,error,match",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x.reshape((-1, -1, 2)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
|
||||
ValueError,
|
||||
"shapes are not compatible (old shape (2, 3, 4), new shape (-1, -1, 2))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.reshape((3, -1, 3)),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(4), shape=(2, 3, 4)),
|
||||
},
|
||||
[numpy.random.randint(0, 2 ** 4, size=(2, 3, 4)) for _ in range(10)],
|
||||
ValueError,
|
||||
"shapes are not compatible (old shape (2, 3, 4), new shape (3, -1, 3))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_memory_operation_failed_compilation(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
error,
|
||||
match,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""
|
||||
Test compilation failures of compiled function with memory operations.
|
||||
|
||||
e.g.,
|
||||
- reshape
|
||||
"""
|
||||
|
||||
with pytest.raises(error) as excinfo:
|
||||
compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
str(excinfo.value) == match
|
||||
), f"""
|
||||
|
||||
Actual Output
|
||||
=============
|
||||
{excinfo.value}
|
||||
|
||||
Expected Output
|
||||
===============
|
||||
{match}
|
||||
|
||||
"""
|
||||
@@ -644,7 +644,7 @@ def test_tracing_numpy_calls(
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x.flatten(),
|
||||
@@ -985,7 +985,7 @@ def test_tracing_ndarray_calls(
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
|
||||
@@ -294,7 +294,7 @@ def test_tracing_numpy_calls(
|
||||
None,
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(strict=True, raises=AssertionError),
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -178,7 +178,7 @@ def test_nptracer_unsupported_operands(operation, exception_type, match):
|
||||
)
|
||||
def test_errors_with_generic_function(lambda_f, params):
|
||||
"Test some errors with generic function"
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
|
||||
|
||||
Reference in New Issue
Block a user