mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: factorize multi lookup tables
This commit is contained in:
@@ -293,29 +293,38 @@ class IntermediateNodeConverter:
|
||||
tables = additional_conversion_info["tables"][self.node]
|
||||
assert_true(len(tables) > 0)
|
||||
|
||||
lut_shape: Tuple[int, ...] = ()
|
||||
map_shape: Tuple[int, ...] = ()
|
||||
|
||||
if len(tables) == 1:
|
||||
table = tables[0][0]
|
||||
|
||||
lut_shape: Tuple[int, ...] = (len(table),)
|
||||
|
||||
# The reduction on 63b is to avoid problems like doing a TLU of
|
||||
# the form T[j] = 2<<j, for j which is supposed to be 7b as per
|
||||
# constraint of the compiler, while in practice, it is a small
|
||||
# value. Reducing on 64b was not ok for some reason
|
||||
lut_shape = (len(table),)
|
||||
lut_values = numpy.array(table % (2 << 63), dtype=numpy.uint64)
|
||||
|
||||
map_shape = ()
|
||||
map_values = None
|
||||
else:
|
||||
assert_true(isinstance(output, TensorValue))
|
||||
assert isinstance(output, TensorValue)
|
||||
|
||||
individual_table_size = len(tables[0][0])
|
||||
lut_shape = (*output.shape, individual_table_size)
|
||||
|
||||
lut_shape = (len(tables), individual_table_size)
|
||||
map_shape = output.shape
|
||||
|
||||
lut_values = numpy.zeros(lut_shape, dtype=numpy.uint64)
|
||||
for table, indices in tables:
|
||||
map_values = numpy.zeros(map_shape, dtype=numpy.intp)
|
||||
|
||||
for i, (table, indices) in enumerate(tables):
|
||||
assert_true(len(table) == individual_table_size)
|
||||
lut_values[i, :] = table
|
||||
for index in indices:
|
||||
index = (*index, slice(None, None, 1))
|
||||
lut_values[index] = table
|
||||
map_values[index] = i
|
||||
|
||||
lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
|
||||
lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
|
||||
@@ -328,7 +337,19 @@ class IntermediateNodeConverter:
|
||||
if len(tables) == 1:
|
||||
result = fhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
else:
|
||||
result = fhelinalg.ApplyMultiLookupTableEintOp(resulting_type, pred, lut).result
|
||||
assert_true(map_shape != ())
|
||||
assert_true(map_values is not None)
|
||||
|
||||
index_type = IndexType.parse("index")
|
||||
map_type = RankedTensorType.get(map_shape, index_type)
|
||||
map_attr = DenseElementsAttr.get(map_values, context=self.ctx, type=index_type)
|
||||
|
||||
result = fhelinalg.ApplyMappedLookupTableEintOp(
|
||||
resulting_type,
|
||||
pred,
|
||||
lut,
|
||||
arith.ConstantOp(map_type, map_attr).result,
|
||||
).result
|
||||
else:
|
||||
result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user