feat(mlir): implement MLIR conversion of multi lookup tables

This commit is contained in:
Umut
2021-11-17 11:12:35 +03:00
parent 78f82fb9a1
commit 1d691f232b
3 changed files with 58 additions and 31 deletions

View File

@@ -1049,6 +1049,34 @@ def test_compile_and_run_lut_correctness(
check_is_good_execution(compiler_engine, function, args)
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
"""Test correctness of results when running a compiled function with Multi LUT"""
def function_to_compile(x):
table = MultiLookupTable(
[
[LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])],
[LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])],
[LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])],
]
)
return table[x]
compiler_engine = compile_numpy_function(
function_to_compile,
{
"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)),
},
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(10)],
default_compilation_configuration,
)
# testing random values
for _ in range(10):
args = [numpy.random.randint(0, 2 ** 2, size=(3, 2), dtype=numpy.uint8)]
check_is_good_execution(compiler_engine, function_to_compile, args)
def test_compile_function_with_direct_tlu(default_compilation_configuration):
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
@@ -1225,23 +1253,6 @@ return %2
""".strip() # noqa: E501
),
),
pytest.param(
multi_lut,
{"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
%1 = MultiTLU(%0, input_shape=(3, 2), tables=[[[1, 2, 1 ... 1, 2, 0]]]) # EncryptedTensor<uint2, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
return %1
""".strip() # noqa: E501
),
),
pytest.param(
lambda x: numpy.transpose(x),
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},