mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(mlir): implement MLIR conversion of multi lookup tables
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
|
||||
from typing import Any, Dict, List, cast
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
import numpy
|
||||
from mlir.dialects import arith
|
||||
@@ -231,6 +231,7 @@ class IntermediateNodeConverter:
|
||||
variable_input_index = variable_input_indices[0]
|
||||
|
||||
assert_true(len(self.node.outputs) == 1)
|
||||
output = self.node.outputs[0]
|
||||
|
||||
value = self.node.inputs[variable_input_index]
|
||||
assert_true(value.is_encrypted)
|
||||
@@ -241,25 +242,39 @@ class IntermediateNodeConverter:
|
||||
raise NotImplementedError(f"Table lookup on {value} cannot be converted to MLIR yet")
|
||||
|
||||
tables = additional_conversion_info["tables"][self.node]
|
||||
assert_true(len(tables) > 0)
|
||||
|
||||
# TODO: #559 adapt the code to support multi TLUs
|
||||
# This cannot be reached today as compilation fails
|
||||
# if the intermediate values are not all scalars
|
||||
if len(tables) > 1: # pragma: no cover
|
||||
raise NotImplementedError("Multi table lookups cannot be converted to MLIR yet")
|
||||
if len(tables) == 1:
|
||||
table = tables[0][0]
|
||||
|
||||
table = tables[0][0]
|
||||
lut_shape: Tuple[int, ...] = (len(table),)
|
||||
lut_values = numpy.array(table, dtype=numpy.uint64)
|
||||
else:
|
||||
assert_true(isinstance(output, TensorValue))
|
||||
assert isinstance(output, TensorValue)
|
||||
|
||||
lut_size = len(table)
|
||||
lut_type = RankedTensorType.get([lut_size], IntegerType.get_signless(64, context=self.ctx))
|
||||
lut_attr = DenseElementsAttr.get(numpy.array(table, dtype=numpy.uint64), context=self.ctx)
|
||||
individual_table_size = len(tables[0][0])
|
||||
lut_shape = (*output.shape, individual_table_size)
|
||||
|
||||
lut_values = numpy.zeros(lut_shape, dtype=numpy.uint64)
|
||||
for table, indices in tables:
|
||||
assert_true(len(table) == individual_table_size)
|
||||
for index in indices:
|
||||
index = (*index, slice(None, None, 1))
|
||||
lut_values[index] = table
|
||||
|
||||
lut_type = RankedTensorType.get(lut_shape, IntegerType.get_signless(64, context=self.ctx))
|
||||
lut_attr = DenseElementsAttr.get(lut_values, context=self.ctx)
|
||||
lut = arith.ConstantOp(lut_type, lut_attr).result
|
||||
|
||||
resulting_type = value_to_mlir_type(self.ctx, self.node.outputs[0])
|
||||
resulting_type = value_to_mlir_type(self.ctx, output)
|
||||
pred = self.preds[variable_input_index]
|
||||
|
||||
if self.one_of_the_inputs_is_a_tensor:
|
||||
result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
if len(tables) == 1:
|
||||
result = hlfhelinalg.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
else:
|
||||
result = hlfhelinalg.ApplyMultiLookupTableEintOp(resulting_type, pred, lut).result
|
||||
else:
|
||||
result = hlfhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result
|
||||
|
||||
|
||||
@@ -89,14 +89,15 @@ def check_node_compatibility_with_mlir(
|
||||
)
|
||||
== 1
|
||||
)
|
||||
if node.op_name == "MultiTLU":
|
||||
return "direct multi table lookup is not supported for the time being"
|
||||
|
||||
if not value_is_unsigned_integer(inputs[0]):
|
||||
# this branch is not reachable because compilation fails during inputset evaluation
|
||||
if node.op_name == "TLU": # pragma: no cover
|
||||
return "only unsigned integer lookup tables are supported"
|
||||
|
||||
if node.op_name == "MultiTLU": # pragma: no cover
|
||||
return "only unsigned integer multi lookup tables are supported"
|
||||
|
||||
# e.g., `np.absolute is not supported for the time being`
|
||||
return f"{node.op_name} is not supported for the time being"
|
||||
else:
|
||||
|
||||
@@ -1049,6 +1049,34 @@ def test_compile_and_run_lut_correctness(
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
|
||||
"""Test correctness of results when running a compiled function with Multi LUT"""
|
||||
|
||||
def function_to_compile(x):
|
||||
table = MultiLookupTable(
|
||||
[
|
||||
[LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])],
|
||||
[LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])],
|
||||
[LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])],
|
||||
]
|
||||
)
|
||||
return table[x]
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function_to_compile,
|
||||
{
|
||||
"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2)),
|
||||
},
|
||||
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# testing random values
|
||||
for _ in range(10):
|
||||
args = [numpy.random.randint(0, 2 ** 2, size=(3, 2), dtype=numpy.uint8)]
|
||||
check_is_good_execution(compiler_engine, function_to_compile, args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
@@ -1225,23 +1253,6 @@ return %2
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
multi_lut,
|
||||
{"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)],
|
||||
(
|
||||
"""
|
||||
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<uint2, shape=(3, 2)>
|
||||
%1 = MultiTLU(%0, input_shape=(3, 2), tables=[[[1, 2, 1 ... 1, 2, 0]]]) # EncryptedTensor<uint2, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
|
||||
return %1
|
||||
|
||||
""".strip() # noqa: E501
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: numpy.transpose(x),
|
||||
{"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))},
|
||||
|
||||
Reference in New Issue
Block a user