refactor: factorize multi lookup tables

This commit is contained in:
Umut
2022-01-05 10:02:30 +03:00
parent 73596b3b7d
commit c7b9380b4c

View File

@@ -293,29 +293,38 @@ class IntermediateNodeConverter:
tables = additional_conversion_info["tables"][self.node]
assert_true(len(tables) > 0)
lut_shape: Tuple[int, ...] = ()
map_shape: Tuple[int, ...] = ()
if len(tables) == 1:
table = tables[0][0]
lut_shape: Tuple[int, ...] = (len(table),)
# The reduction on 63b is to avoid problems like doing a TLU of
# the form T[j] = 2<<j, for j which is supposed to be 7b as per
# constraint of the compiler, while in practice, it is a small
# value. Reducing on 64b was not ok for some reason
lut_shape = (len(table),)
lut_values = numpy.array(table % (2 << 63), dtype=numpy.uint64)
map_shape = ()
map_values = None
else:
assert_true(isinstance(output, TensorValue))
assert isinstance(output, TensorValue)
individual_table_size = len(tables[0][0])
lut_shape = (*output.shape, individual_table_size)
lut_shape = (len(tables), individual_table_size)
map_shape = output.shape
lut_values = numpy.zeros(lut_shape, dtype=numpy.uint64)
for table, indices in tables:
map_values = numpy.zeros(map_shape, dtype=numpy.intp)
for i, (table, indices) in enumerate(tables):
assert_true(len(table) == individual_table_size)
lut_values[i, :] = table
for index in indices:
index = (*index, slice(None, None, 1))
lut_values[index] = table
map_values[index] = i
lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
@@ -328,7 +337,19 @@ class IntermediateNodeConverter:
if len(tables) == 1:
result = fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
else:
result = fhelinalg.ApplyMultiLookupTableEintOp(resulting_type, pred, lut).result
assert_true(map_shape != ())
assert_true(map_values is not None)
index_type = IndexType.parse("index")
map_type = RankedTensorType.get(map_shape, index_type)
map_attr = DenseElementsAttr.get(map_values, context=self.ctx, type=index_type)
result = fhelinalg.ApplyMappedLookupTableEintOp(
resulting_type,
pred,
lut,
arith.ConstantOp(map_type, map_attr).result,
).result
else:
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result