diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 9cfae6597..97739a49a 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -79,6 +79,13 @@ def update_bit_width_for_mlir(op_graph: OPGraph): if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB: offending_list.append((node, current_node_out_bit_width)) + # TODO: remove this workaround, which was for #279, once the compiler can handle + # smaller tables, #412 + has_a_table = any(isinstance(node, ArbitraryFunction) for node in op_graph.graph.nodes) + + if has_a_table: + max_bit_width = ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB + _set_all_bit_width(op_graph, max_bit_width) # Check that the max_bit_width is supported by the compiler diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 14be89ef1..3919b6e8f 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -30,6 +30,17 @@ def lut(x): return table[x] +def small_lut(x): + """Test lookup table with small size and output""" + table = LookupTable(list(range(32))) + return table[x] + + +def small_fused_table(x): + """Test with a small fused table""" + return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32) + + @pytest.mark.parametrize( "function,input_ranges,list_of_arg_names", [ @@ -84,6 +95,8 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]), pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]), pytest.param(lut, ((0, 127),), ["x"]), + pytest.param(small_lut, ((0, 31),), ["x"]), + pytest.param(small_fused_table, ((0, 31),), ["x"]), ], ) def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):