mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(mlir): implement mlir conversion of basic tensor operations
This commit is contained in:
@@ -22,8 +22,8 @@ def no_fuse_unhandled(x, y):
|
||||
"""No fuse unhandled"""
|
||||
x_intermediate = x + 2.8
|
||||
y_intermediate = y + 9.3
|
||||
intermediate = x_intermediate + y_intermediate
|
||||
return intermediate.astype(numpy.int32)
|
||||
intermediate = x_intermediate - y_intermediate
|
||||
return (intermediate * 1.5).astype(numpy.int32)
|
||||
|
||||
|
||||
def identity_lut_generator(n):
|
||||
@@ -540,6 +540,228 @@ def test_compile_and_run_correctness(
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,test_input,expected_output",
|
||||
[
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 8],
|
||||
[7, 2],
|
||||
[3, 6],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 7],
|
||||
[8, 1],
|
||||
[5, 6],
|
||||
],
|
||||
),
|
||||
# TODO: find a way to support this case
|
||||
# https://github.com/zama-ai/concretefhe-internal/issues/837
|
||||
#
|
||||
# the problem is that compiler doesn't support combining scalars and tensors
|
||||
# but they do support broadcasting, so scalars can be converted to (1,) shaped tensors
|
||||
# this is easy with known constants but weird with variable things such as another input
|
||||
# there is tensor.from_elements but I coudn't figure out how to use it in the python API
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
"y": EncryptedScalar(UnsignedInteger(3)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
random.randint(0, (2 ** 3) - 1),
|
||||
)
|
||||
for _ in range(10)
|
||||
],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
2,
|
||||
),
|
||||
[
|
||||
[2, 9],
|
||||
[8, 3],
|
||||
[4, 7],
|
||||
],
|
||||
marks=pytest.mark.xfail(),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x, y: x + y,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
"y": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[
|
||||
(
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
|
||||
)
|
||||
for _ in range(10)
|
||||
],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
[
|
||||
[1, 6],
|
||||
[2, 5],
|
||||
[3, 4],
|
||||
],
|
||||
),
|
||||
[
|
||||
[1, 13],
|
||||
[8, 6],
|
||||
[5, 9],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 100 - x,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[100, 93],
|
||||
[94, 99],
|
||||
[98, 95],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[10, 8],
|
||||
[14, 14],
|
||||
[8, 25],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 2,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[0, 14],
|
||||
[12, 2],
|
||||
[4, 10],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[4, 7],
|
||||
[6, 1],
|
||||
[2, 5],
|
||||
],
|
||||
),
|
||||
[
|
||||
[4, 14],
|
||||
[12, 1],
|
||||
[6, 5],
|
||||
],
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: LookupTable([2, 1, 3, 0])[x],
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(10)],
|
||||
(
|
||||
[
|
||||
[0, 1],
|
||||
[2, 1],
|
||||
[3, 0],
|
||||
],
|
||||
),
|
||||
[
|
||||
[2, 1],
|
||||
[3, 1],
|
||||
[0, 2],
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_tensor_correctness(
|
||||
function, parameters, inputset, test_input, expected_output, default_compilation_configuration
|
||||
):
|
||||
"""Test correctness of results when running a compiled function with tensor operators"""
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
parameters,
|
||||
inputset,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
numpy_test_input = (numpy.array(item, dtype=numpy.uint8) for item in test_input)
|
||||
assert numpy.array_equal(
|
||||
circuit.run(*numpy_test_input),
|
||||
numpy.array(expected_output, dtype=numpy.uint8),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size, input_range",
|
||||
[
|
||||
@@ -769,71 +991,7 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>
|
||||
%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 8, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x * 1,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
|
||||
%2 = Mul(%0, %1) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: 127 - x,
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = Constant(127) # ClearScalar<Integer<unsigned, 7 bits>>
|
||||
%1 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
|
||||
%2 = Sub(%0, %1) # EncryptedTensor<Integer<unsigned, 7 bits>, shape=(2, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
@@ -892,22 +1050,27 @@ return(%1)
|
||||
[(numpy.array(i), numpy.array(i)) for i in range(10)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering\n
|
||||
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%1 = Constant(2.8) # ClearScalar<Float<64 bits>>
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = Constant(1.5) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%2 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%3 = Constant(9.3) # ClearScalar<Float<64 bits>>
|
||||
%1 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%2 = Constant(2.8) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
|
||||
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
|
||||
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
|
||||
%7 = astype(int32)(%6) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported
|
||||
return(%7)
|
||||
%3 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%4 = Constant(9.3) # ClearScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
|
||||
%5 = Add(%1, %2) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%6 = Add(%3, %4) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
|
||||
%7 = Sub(%5, %6) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
|
||||
%8 = Mul(%7, %0) # EncryptedScalar<Float<64 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
|
||||
%9 = astype(int32)(%8) # EncryptedScalar<Integer<signed, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype(int32) is not supported for the time being
|
||||
return(%9)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
@@ -1057,7 +1220,7 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>
|
||||
%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.absolute is not supported for the time being
|
||||
%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>
|
||||
%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%7)
|
||||
@@ -1255,7 +1418,7 @@ function you are trying to compile isn't supported for MLIR lowering
|
||||
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
|
||||
%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
|
||||
return(%2)
|
||||
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user