From bd8dca11d5677439753ec17518ceec11a92557ac Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 14 Sep 2021 16:26:21 +0100 Subject: [PATCH] fix(TLU): extend TLU to 2 ** bit_width elements --- concrete/common/mlir/utils.py | 19 ++++++++++++++++ concrete/numpy/compile.py | 4 ++++ tests/common/mlir/test_mlir_converter.py | 29 ++++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index c6dec7c8d..60779a3d7 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -10,6 +10,7 @@ from ..data_types.dtypes_helpers import ( value_is_scalar_integer, ) from ..operator_graph import OPGraph +from ..representation.intermediate import ArbitraryFunction def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool: @@ -63,3 +64,21 @@ def update_bit_width_for_mlir(op_graph: OPGraph): ): max_bit_width = max(max_bit_width, value_out.data_type.bit_width) _set_all_bit_width(op_graph, max_bit_width) + + +def extend_direct_lookup_tables(op_graph: OPGraph): + """Extend direct lookup tables to the maximum length the input bit width can support. + + Args: + op_graph: graph to update lookup tables for + """ + for node in op_graph.graph.nodes: + if isinstance(node, ArbitraryFunction) and node.op_name == "TLU": + table = node.op_kwargs["table"] + bit_width = cast(Integer, node.inputs[0].data_type).bit_width + expected_length = 2 ** bit_width + if len(table) > expected_length: + node.op_kwargs["table"] = table[:expected_length] + else: + repeat = expected_length // len(table) + node.op_kwargs["table"] = (table * repeat)[:expected_length] diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index fd9cd6359..ea3f964a9 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -11,6 +11,7 @@ from ..common.common_helpers import check_op_graph_is_integer_program from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter from ..common.mlir.utils import ( + extend_direct_lookup_tables, is_graph_values_compatible_with_mlir, update_bit_width_for_mlir, ) @@ -133,6 +134,9 @@ def _compile_numpy_function_into_op_graph_internal( # Update bit_width for MLIR update_bit_width_for_mlir(op_graph) + # TODO: workaround extend LUT #359 + extend_direct_lookup_tables(op_graph) + return op_graph diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index f1378a0b8..5769dffda 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -67,6 +67,20 @@ def lut(x): return table[x] +# TODO: remove workaround #359 +def lut_more_bits_than_table_length(x, y): + """Test lookup table when bit_width support longer LUT""" + table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7]) + return table[x] + y + + +# TODO: remove workaround #359 +def lut_less_bits_than_table_length(x): + """Test lookup table when bit_width support smaller LUT""" + table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7]) + return table[x] + + def dot(x, y): """Test dot""" return numpy.dot(x, y) @@ -184,6 +198,21 @@ def datagen(*args): }, (range(0, 8),), ), + ( + lut_more_bits_than_table_length, + { + "x": EncryptedScalar(Integer(64, is_signed=False)), + "y": EncryptedScalar(Integer(64, is_signed=False)), + }, + (range(0, 8), range(0, 16)), + ), + ( + lut_less_bits_than_table_length, + { + "x": EncryptedScalar(Integer(64, is_signed=False)), + }, + (range(0, 8),), + ), ( dot, {