From 1d691f232b7e99e2ee32ab0a83f2fcbb5aa76e39 Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 17 Nov 2021 11:12:35 +0300 Subject: [PATCH] feat(mlir): implement MLIR conversion of multi lookup tables --- concrete/common/mlir/node_converter.py | 39 +++++++++++++++------- concrete/common/mlir/utils.py | 5 +-- tests/numpy/test_compile.py | 45 ++++++++++++++++---------- 3 files changed, 58 insertions(+), 31 deletions(-) diff --git a/concrete/common/mlir/node_converter.py b/concrete/common/mlir/node_converter.py index 3b10fb23b..030779991 100644 --- a/concrete/common/mlir/node_converter.py +++ b/concrete/common/mlir/node_converter.py @@ -4,7 +4,7 @@ # pylint: disable=no-name-in-module -from typing import Any, Dict, List, cast +from typing import Any, Dict, List, Tuple, cast import numpy from mlir.dialects import arith @@ -231,6 +231,7 @@ class IntermediateNodeConverter: variable_input_index = variable_input_indices[0] assert_true(len(self.node.outputs) == 1) + output = self.node.outputs[0] value = self.node.inputs[variable_input_index] assert_true(value.is_encrypted) @@ -241,25 +242,39 @@ class IntermediateNodeConverter: raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet") tables = additional_conversion_info["tables"][self.node] + assert_true(len(tables) > 0) - # TODO: #559 adapt the code to support multi TLUs - # This cannot be reached today as compilation fails - # if the intermediate values are not all scalars - if len(tables) > 1: # pragma: no cover - raise NotImplementedError("Multi table lookups cannot be converted to MLIR yet") + if len(tables) == 1: + table = tables[0][0] - table = tables[0][0] + lut_shape: Tuple[int, ...] = (len(table),) + lut_values = numpy.array(table, dtype=numpy.uint64) + else: + assert_true(isinstance(output, TensorValue)) + assert isinstance(output, TensorValue) - lut_size = len(table) - lut_type = RankedTensorType.get([lut_size], IntegerType.get_signless(64, context=self.ctx)) - lut_attr = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=self.ctx) + individual_table_size = len(tables[0][0]) + lut_shape = (*output.shape, individual_table_size) + + lut_values = numpy.zeros(lut_shape, dtype=numpy.uint64) + for table, indices in tables: + assert_true(len(table) == individual_table_size) + for index in indices: + index = (*index, slice(None, None, 1)) + lut_values[index] = table + + lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx)) + lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx) lut = arith.ConstantOp(lut_type, lut_attr).result - resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0]) + resulting_type = value_to_mlir_type(self.ctx, output) pred = self.preds[variable_input_index] if self.one_of_the_inputs_is_a_tensor: - result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result + if len(tables) == 1: + result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result + else: + result = hlfhelinalg.ApplyMultiLookupTableEintOp(resulting_type, pred, lut).result else: result = hlfhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 406d034e6..653589f37 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -89,14 +89,15 @@ def check_node_compatibility_with_mlir( ) == 1 ) - if node.op_name == "MultiTLU": - return "direct multi table lookup is not supported for the time being" if not value_is_unsigned_integer(inputs[0]): # this branch is not reachable because compilation fails during inputset evaluation if node.op_name == "TLU": # pragma: no cover return "only unsigned integer lookup tables are supported" + if node.op_name == "MultiTLU": # pragma: no cover + return "only unsigned integer multi lookup tables are supported" + # e.g., `np.absolute is not supported for the time being` return f"{node.op_name} is not supported for the time being" else: diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index c5b6ed1f0..b33d3103d 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1049,6 +1049,34 @@ def test_compile_and_run_lut_correctness( check_is_good_execution(compiler_engine, function, args) +def test_compile_and_run_multi_lut_correctness(default_compilation_configuration): + """Test correctness of results when running a compiled function with Multi LUT""" + + def function_to_compile(x): + table = MultiLookupTable( + [ + [LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])], + [LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])], + [LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])], + ] + ) + return table[x] + + compiler_engine = compile_numpy_function( + function_to_compile, + { + "x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)), + }, + [(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(10)], + default_compilation_configuration, + ) + + # testing random values + for _ in range(10): + args = [numpy.random.randint(0, 2 ** 2, size=(3, 2), dtype=numpy.uint8)] + check_is_good_execution(compiler_engine, function_to_compile, args) + + def test_compile_function_with_direct_tlu(default_compilation_configuration): """Test compile_numpy_function_into_op_graph for a program with direct table lookup""" @@ -1225,23 +1253,6 @@ return %2 """.strip() # noqa: E501 ), ), - pytest.param( - multi_lut, - {"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))}, - [(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)], - ( - """ - -function you are trying to compile isn't supported for MLIR lowering - -%0 = x # EncryptedTensor -%1 = MultiTLU(%0, input_shape=(3, 2), tables=[[[1, 2, 1 ... 1, 2, 0]]]) # EncryptedTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being -return %1 - - """.strip() # noqa: E501 - ), - ), pytest.param( lambda x: numpy.transpose(x), {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},