workaround: make compilation work if table output is small

this workaround will be removed once it is managed by the compiler.
closes #279
refs #412
This commit is contained in:
Benoit Chevallier-Mames
2021-09-17 17:03:50 +02:00
committed by Benoit Chevallier
parent b950bb4459
commit 0a6ebf3b19
2 changed files with 20 additions and 0 deletions

View File

@@ -79,6 +79,13 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
offending_list.append((node, current_node_out_bit_width))
# TODO: remove this workaround, which was for #279, once the compiler can handle
# smaller tables, #412
has_a_table = any(isinstance(node, ArbitraryFunction) for node in op_graph.graph.nodes)
if has_a_table:
max_bit_width = ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB
_set_all_bit_width(op_graph, max_bit_width)
# Check that the max_bit_width is supported by the compiler

View File

@@ -30,6 +30,17 @@ def lut(x):
return table[x]
def small_lut(x):
"""Test lookup table with small size and output"""
table = LookupTable(list(range(32)))
return table[x]
def small_fused_table(x):
"""Test with a small fused table"""
return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
@@ -84,6 +95,8 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
pytest.param(lambda x: 8 - x, ((0, 2),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
pytest.param(lut, ((0, 127),), ["x"]),
pytest.param(small_lut, ((0, 31),), ["x"]),
pytest.param(small_fused_table, ((0, 31),), ["x"]),
],
)
def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_of_arg_names):