fix(TLU): extend TLU to 2 ** bit_width elements

This commit is contained in:
youben11
2021-09-14 16:26:21 +01:00
committed by Ayoub Benaissa
parent 5e8a7c527b
commit bd8dca11d5
3 changed files with 52 additions and 0 deletions

View File

@@ -10,6 +10,7 @@ from ..data_types.dtypes_helpers import (
value_is_scalar_integer,
)
from ..operator_graph import OPGraph
from ..representation.intermediate import ArbitraryFunction
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
@@ -63,3 +64,21 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
):
max_bit_width = max(max_bit_width, value_out.data_type.bit_width)
_set_all_bit_width(op_graph, max_bit_width)
def extend_direct_lookup_tables(op_graph: OPGraph):
"""Extend direct lookup tables to the maximum length the input bit width can support.
Args:
op_graph: graph to update lookup tables for
"""
for node in op_graph.graph.nodes:
if isinstance(node, ArbitraryFunction) and node.op_name == "TLU":
table = node.op_kwargs["table"]
bit_width = cast(Integer, node.inputs[0].data_type).bit_width
expected_length = 2 ** bit_width
if len(table) > expected_length:
node.op_kwargs["table"] = table[:expected_length]
else:
repeat = expected_length // len(table)
node.op_kwargs["table"] = (table * repeat)[:expected_length]

View File

@@ -11,6 +11,7 @@ from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS, MLIRConverter
from ..common.mlir.utils import (
extend_direct_lookup_tables,
is_graph_values_compatible_with_mlir,
update_bit_width_for_mlir,
)
@@ -133,6 +134,9 @@ def _compile_numpy_function_into_op_graph_internal(
# Update bit_width for MLIR
update_bit_width_for_mlir(op_graph)
# TODO: workaround extend LUT #359
extend_direct_lookup_tables(op_graph)
return op_graph

View File

@@ -67,6 +67,20 @@ def lut(x):
return table[x]
# TODO: remove workaround #359
def lut_more_bits_than_table_length(x, y):
"""Test lookup table when bit_width support longer LUT"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7])
return table[x] + y
# TODO: remove workaround #359
def lut_less_bits_than_table_length(x):
"""Test lookup table when bit_width support smaller LUT"""
table = LookupTable([3, 6, 0, 2, 1, 4, 5, 7, 3, 6, 0, 2, 1, 4, 5, 7])
return table[x]
def dot(x, y):
"""Test dot"""
return numpy.dot(x, y)
@@ -184,6 +198,21 @@ def datagen(*args):
},
(range(0, 8),),
),
(
lut_more_bits_than_table_length,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
"y": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 8), range(0, 16)),
),
(
lut_less_bits_than_table_length,
{
"x": EncryptedScalar(Integer(64, is_signed=False)),
},
(range(0, 8),),
),
(
dot,
{