feat(mlir): implement mlir conversion of basic tensor operations

This commit is contained in:
Umut
2021-11-05 12:30:20 +03:00
parent e20ad467e3
commit ee202a03b3
5 changed files with 502 additions and 156 deletions

View File

@@ -22,8 +22,8 @@ def no_fuse_unhandled(x, y):
"""No fuse unhandled"""
x_intermediate = x + 2.8
y_intermediate = y + 9.3
intermediate = x_intermediate + y_intermediate
return intermediate.astype(numpy.int32)
intermediate = x_intermediate - y_intermediate
return (intermediate * 1.5).astype(numpy.int32)
def identity_lut_generator(n):
@@ -540,6 +540,228 @@ def test_compile_and_run_correctness(
assert compiler_engine.run(*args) == function(*args)
@pytest.mark.parametrize(
"function,parameters,inputset,test_input,expected_output",
[
pytest.param(
lambda x: x + 1,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[1, 8],
[7, 2],
[3, 6],
],
),
pytest.param(
lambda x: x + numpy.array([[1, 0], [2, 0], [3, 1]], dtype=numpy.uint32),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[1, 7],
[8, 1],
[5, 6],
],
),
# TODO: find a way to support this case
# https://github.com/zama-ai/concretefhe-internal/issues/837
#
# the problem is that compiler doesn't support combining scalars and tensors
# but they do support broadcasting, so scalars can be converted to (1,) shaped tensors
# this is easy with known constants but weird with variable things such as another input
# there is tensor.from_elements but I coudn't figure out how to use it in the python API
pytest.param(
lambda x, y: x + y,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
"y": EncryptedScalar(UnsignedInteger(3)),
},
[
(
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
random.randint(0, (2 ** 3) - 1),
)
for _ in range(10)
],
(
[
[0, 7],
[6, 1],
[2, 5],
],
2,
),
[
[2, 9],
[8, 3],
[4, 7],
],
marks=pytest.mark.xfail(),
),
pytest.param(
lambda x, y: x + y,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
"y": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[
(
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
numpy.random.randint(0, 2 ** 3, size=(3, 2)),
)
for _ in range(10)
],
(
[
[0, 7],
[6, 1],
[2, 5],
],
[
[1, 6],
[2, 5],
[3, 4],
],
),
[
[1, 13],
[8, 6],
[5, 9],
],
),
pytest.param(
lambda x: 100 - x,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[100, 93],
[94, 99],
[98, 95],
],
),
pytest.param(
lambda x: numpy.array([[10, 15], [20, 15], [10, 30]], dtype=numpy.uint32) - x,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[10, 8],
[14, 14],
[8, 25],
],
),
pytest.param(
lambda x: x * 2,
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[0, 7],
[6, 1],
[2, 5],
],
),
[
[0, 14],
[12, 2],
[4, 10],
],
),
pytest.param(
lambda x: x * numpy.array([[1, 2], [2, 1], [3, 1]], dtype=numpy.uint32),
{
"x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)],
(
[
[4, 7],
[6, 1],
[2, 5],
],
),
[
[4, 14],
[12, 1],
[6, 5],
],
),
pytest.param(
lambda x: LookupTable([2, 1, 3, 0])[x],
{
"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(10)],
(
[
[0, 1],
[2, 1],
[3, 0],
],
),
[
[2, 1],
[3, 1],
[0, 2],
],
),
],
)
def test_compile_and_run_tensor_correctness(
function, parameters, inputset, test_input, expected_output, default_compilation_configuration
):
"""Test correctness of results when running a compiled function with tensor operators"""
circuit = compile_numpy_function(
function,
parameters,
inputset,
default_compilation_configuration,
)
numpy_test_input = (numpy.array(item, dtype=numpy.uint8) for item in test_input)
assert numpy.array_equal(
circuit.run(*numpy_test_input),
numpy.array(expected_output, dtype=numpy.uint8),
)
@pytest.mark.parametrize(
"size, input_range",
[
@@ -769,71 +991,7 @@ function you are trying to compile isn't supported for MLIR lowering
%0 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
%1 = x # EncryptedScalar<Integer<unsigned, 3 bits>>
%2 = Sub(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported
return(%2)
""".lstrip() # noqa: E501
),
),
pytest.param(
lambda x: x + 1,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
[(numpy.random.randint(0, 8, size=(2, 2)),) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported
return(%2)
""".lstrip() # noqa: E501
),
),
pytest.param(
lambda x: x + 1,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
%2 = Add(%0, %1) # EncryptedTensor<Integer<unsigned, 4 bits>, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar addition is supported
return(%2)
""".lstrip() # noqa: E501
),
),
pytest.param(
lambda x: x * 1,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
%1 = Constant(1) # ClearScalar<Integer<unsigned, 1 bits>>
%2 = Mul(%0, %1) # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar multiplication is supported
return(%2)
""".lstrip() # noqa: E501
),
),
pytest.param(
lambda x: 127 - x,
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(2, 2))},
[(numpy.random.randint(0, 2 ** 3, size=(2, 2)),) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = Constant(127) # ClearScalar<Integer<unsigned, 7 bits>>
%1 = x # EncryptedTensor<Integer<unsigned, 3 bits>, shape=(2, 2)>
%2 = Sub(%0, %1) # EncryptedTensor<Integer<unsigned, 7 bits>, shape=(2, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar subtraction is supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return(%2)
""".lstrip() # noqa: E501
),
@@ -892,22 +1050,27 @@ return(%1)
[(numpy.array(i), numpy.array(i)) for i in range(10)],
(
"""
function you are trying to compile isn't supported for MLIR lowering\n
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
%1 = Constant(2.8) # ClearScalar<Float<64 bits>>
function you are trying to compile isn't supported for MLIR lowering
%0 = Constant(1.5) # ClearScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%2 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
%3 = Constant(9.3) # ClearScalar<Float<64 bits>>
%1 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
%2 = Constant(2.8) # ClearScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%4 = Add(%0, %1) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
%5 = Add(%2, %3) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
%6 = Add(%4, %5) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer intermediates are supported
%7 = astype(int32)(%6) # EncryptedScalar<Integer<unsigned, 5 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported
return(%7)
%3 = y # EncryptedScalar<Integer<unsigned, 4 bits>>
%4 = Constant(9.3) # ClearScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer constants are supported
%5 = Add(%1, %2) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
%6 = Add(%3, %4) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer addition is supported
%7 = Sub(%5, %6) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer subtraction is supported
%8 = Mul(%7, %0) # EncryptedScalar<Float<64 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only integer multiplication is supported
%9 = astype(int32)(%8) # EncryptedScalar<Integer<signed, 5 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ astype(int32) is not supported for the time being
return(%9)
""".lstrip() # noqa: E501
),
),
@@ -1057,7 +1220,7 @@ function you are trying to compile isn't supported for MLIR lowering
%3 = np.negative(%2) # EncryptedScalar<Integer<signed, 3 bits>>
%4 = Mul(%3, %1) # EncryptedScalar<Integer<signed, 6 bits>>
%5 = np.absolute(%4) # EncryptedScalar<Integer<unsigned, 5 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer scalar lookup tables are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ np.absolute is not supported for the time being
%6 = astype(int32)(%5) # EncryptedScalar<Integer<unsigned, 5 bits>>
%7 = Add(%6, %0) # EncryptedScalar<Integer<unsigned, 6 bits>>
return(%7)
@@ -1255,7 +1418,7 @@ function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedScalar<Integer<unsigned, 4 bits>>
%1 = Constant(-3) # ClearScalar<Integer<signed, 3 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<signed, 4 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only scalar unsigned integer outputs are supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only unsigned integer outputs are supported
return(%2)
""".lstrip() # noqa: E501 # pylint: disable=line-too-long
)