diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py
index 981ebe7ce..39651a8b4 100644
--- a/concrete/common/extensions/table.py
+++ b/concrete/common/extensions/table.py
@@ -62,7 +62,7 @@ class LookupTable:
 
     @staticmethod
     def _check_index_out_of_range(x, table):
-        if x < 0 or x >= len(table):
+        if not -len(table) <= x < len(table):
             raise ValueError(
                 f"Lookup table with {len(table)} entries cannot be indexed with {x} "
                 f"(you should check your inputset)",
diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py
index 62c386774..1283963e0 100644
--- a/concrete/common/mlir/utils.py
+++ b/concrete/common/mlir/utils.py
@@ -1,9 +1,8 @@
 """Utilities for MLIR conversion."""
-from typing import Dict, List, Optional, cast
+from typing import Dict, List, Optional
 
 import networkx as nx
 
-from ..data_types import Integer
 from ..data_types.dtypes_helpers import (
     value_is_clear_scalar_integer,
     value_is_clear_tensor_integer,
@@ -16,7+15,7 @@ from ..debugging import format_operation_graph
 from ..debugging.custom_assert import assert_not_reached, assert_true
 from ..operator_graph import OPGraph
 from ..representation import intermediate
-from ..representation.intermediate import GenericFunction, IntermediateNode
+from ..representation.intermediate import IntermediateNode
 
 # TODO: should come from compiler, through an API, #402
 ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB = 7
@@ -210,24 +209,3 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
         )
 
     _set_all_bit_width(op_graph, max_bit_width)
-
-
-def extend_direct_lookup_tables(op_graph: OPGraph):
-    """Extend direct lookup tables to the maximum length the input bit width can support.
-
-    Args:
-        op_graph: graph to update lookup tables for
-    """
-    for node in op_graph.graph.nodes:
-        if isinstance(node, GenericFunction) and node.op_name == "TLU":
-            table = node.op_kwargs["table"]
-            bit_width = cast(Integer, node.inputs[0].dtype).bit_width
-            expected_length = 2 ** bit_width
-
-            # TODO: remove no cover once the table length workaround is removed
-            # (https://github.com/zama-ai/concretefhe-internal/issues/359)
-            if len(table) > expected_length:  # pragma: no cover
-                node.op_kwargs["table"] = table[:expected_length]
-            else:
-                repeat = expected_length // len(table)
-                node.op_kwargs["table"] = (table * repeat)[:expected_length]
diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py
index b26120ece..f2ec41ad7 100644
--- a/concrete/numpy/compile.py
+++ b/concrete/numpy/compile.py
@@ -17,7 +17,6 @@ from ..common.debugging.custom_assert import assert_true
 from ..common.fhe_circuit import FHECircuit
 from ..common.mlir.utils import (
     check_graph_values_compatibility_with_mlir,
-    extend_direct_lookup_tables,
     update_bit_width_for_mlir,
 )
 from ..common.operator_graph import OPGraph
@@ -601,9 +600,6 @@ def prepare_op_graph_for_mlir(op_graph: OPGraph):
     # Update bit_width for MLIR
     update_bit_width_for_mlir(op_graph)
 
-    # TODO: workaround extend LUT #359
-    extend_direct_lookup_tables(op_graph)
-
     # HACK
     # TODO: remove this ugly hack when https://github.com/zama-ai/concretefhe-internal/issues/1001
     # is done
diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py
index 5a242e7bc..73eee9005 100644
--- a/tests/numpy/test_compile.py
+++ b/tests/numpy/test_compile.py
@@ -50,6 +50,42 @@ def identity_lut_generator(n):
     return lambda x: LookupTable(list(range(2 ** n)))[x]
 
 
+def negative_identity_smaller_lut_generator(n):
+    """Test negative lookup table"""
+
+    table = LookupTable(range(2 ** (n - 1)))
+    offset = 2 ** (n - 1)
+
+    return (lambda x: table[x + (-offset)]), table
+
+
+def negative_identity_lut_generator(n):
+    """Test negative lookup table (bigger than bit-width)"""
+
+    table = LookupTable(range(2 ** n))
+    offset = 2 ** (n - 1)
+
+    return (lambda x: table[x + (-offset)]), table
+
+
+def negative_identity_bigger_lut_generator(n):
+    """Test negative lookup table (bigger than bit-width)"""
+
+    table = LookupTable(range(2 ** (n + 1)))
+    offset = 2 ** (n - 1)
+
+    return (lambda x: table[x + (-offset)]), table
+
+
+def weird_lut(n):
+    """A weird lookup table to test an edge case related to negative indexing"""
+
+    table = LookupTable([0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 6, 7])
+    offset = 2 ** (n - 1)
+
+    return (lambda x: table[x + (-offset)]), table
+
+
 def random_lut_1b(x):
     """1-bit random table lookup"""
 
@@ -1375,6 +1411,45 @@ def test_compile_and_run_lut_correctness(
     check_is_good_execution(compiler_engine, function, args)
 
 
+@pytest.mark.parametrize(
+    "function,table,bit_width",
+    [
+        pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)")
+        for n in range(1, 8)
+    ]
+    + [
+        pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)")
+        for n in range(1, 8)
+    ]
+    + [
+        pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)")
+        for n in range(1, 7)
+    ]
+    + [
+        pytest.param(*weird_lut(3), 3, id="weird"),
+    ],
+)
+def test_compile_and_run_negative_lut_correctness(
+    function,
+    table,
+    bit_width,
+    default_compilation_configuration,
+):
+    """Test correctness when running a compiled function with LUT using negative values"""
+
+    circuit = compile_numpy_function(
+        function,
+        {"x": EncryptedScalar(UnsignedInteger(bit_width))},
+        range(2 ** bit_width),
+        default_compilation_configuration,
+    )
+
+    offset = 2 ** (bit_width - 1)
+    for value in range(-offset, offset):
+        assert table[value] == function(value + offset)
+        check_is_good_execution(circuit, function, [value + offset])
+
+
 def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
     """Test correctness of results when running a compiled function with Multi LUT"""
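
Note: the core behavioural change is the relaxed bounds check in LookupTable._check_index_out_of_range, which now accepts Python-style negative indices (wrapping from the end of the table) instead of rejecting anything below zero. The following is a minimal standalone sketch of the accepted range; the module-level helper name and the plain list standing in for a LookupTable are illustrative only, not part of the patch.

def check_index_out_of_range(x, table):
    # Same condition as the patched method: any x with -len(table) <= x < len(table) is valid.
    if not -len(table) <= x < len(table):
        raise ValueError(
            f"Lookup table with {len(table)} entries cannot be indexed with {x} "
            f"(you should check your inputset)"
        )

table = [10, 20, 30, 40]
check_index_out_of_range(-1, table)  # accepted: -1 wraps to table[3] == 40
check_index_out_of_range(3, table)   # accepted: largest positive index
# check_index_out_of_range(4, table)   would raise ValueError
# check_index_out_of_range(-5, table)  would raise ValueError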