mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(extensions): create multi table lookup extension
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
"""Extensions module to provide additional functionality to our users."""
|
||||
from . import table
|
||||
from . import multi_table, table
|
||||
|
||||
222
concrete/common/extensions/multi_table.py
Normal file
222
concrete/common/extensions/multi_table.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""This file contains a wrapper class for direct multi table lookups."""
|
||||
|
||||
import itertools
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
|
||||
from ..representation.intermediate import UnivariateFunction
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
from .table import LookupTable
|
||||
|
||||
|
||||
class MultiLookupTable:
|
||||
"""Class representing a multi lookup table."""
|
||||
|
||||
# Multi table lookup is needed when you want to perform a lookup on a tensor,
|
||||
# but you want each element to be used with a different lookup table.
|
||||
#
|
||||
# Here is an example:
|
||||
#
|
||||
# You have x which is of shape (2, 3),
|
||||
# you want the first row to be indexed with `table1 = LookupTable([2, 3, 1, 0])`
|
||||
# and the second row to be indexed with `table1 = LookupTable([0, 1, 3, 2])`
|
||||
#
|
||||
# You can create such a multi lookup table
|
||||
# multitable = MultiLookupTable(
|
||||
# [
|
||||
# [table1, table1, table1],
|
||||
# [table2, table2, table2],
|
||||
# ],
|
||||
# )
|
||||
# (notice the shape of multitable matches with the shape of x)
|
||||
#
|
||||
# and use multitable[x] toget the following result
|
||||
# assert multitable[x] == [
|
||||
# [table1[x[0, 0]], table1[x[0, 1]], table1[x[0, 2]]],
|
||||
# [table2[x[1, 0]], table2[x[1, 1]], table2[x[1, 2]]],
|
||||
# ]
|
||||
|
||||
# underlying lookup tables
|
||||
tables: List
|
||||
|
||||
# shape of the input of the lookup
|
||||
input_shape: Tuple[int, ...]
|
||||
|
||||
# type of the result of the lookup
|
||||
output_dtype: BaseDataType
|
||||
|
||||
def __init__(self, tables: List):
|
||||
input_shape_list: List[int] = []
|
||||
MultiLookupTable._extract_shape_using_first_elements_only(tables, input_shape_list)
|
||||
input_shape: Tuple[int, ...] = tuple(input_shape_list)
|
||||
|
||||
table_sizes: List[int] = []
|
||||
table_output_dtypes: List[BaseDataType] = []
|
||||
MultiLookupTable._check_shape_and_record_luts(
|
||||
tables,
|
||||
0,
|
||||
input_shape,
|
||||
table_sizes,
|
||||
table_output_dtypes,
|
||||
)
|
||||
|
||||
for i in range(1, len(table_sizes)):
|
||||
if table_sizes[i - 1] != table_sizes[i]:
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [table1, table2, table1],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
raise ValueError(
|
||||
f"LookupTables within a MultiLookupTable "
|
||||
f"should have the same size but they do not "
|
||||
f"(there was a table with the size of {table_sizes[i - 1]} "
|
||||
f"and another with the size of {table_sizes[i]})"
|
||||
)
|
||||
|
||||
output_dtype = table_output_dtypes[0]
|
||||
for table_output_dtype in table_output_dtypes:
|
||||
output_dtype = find_type_to_hold_both_lossy(output_dtype, table_output_dtype)
|
||||
|
||||
self.tables = tables
|
||||
self.input_shape = input_shape
|
||||
self.output_dtype = output_dtype
|
||||
|
||||
def __getitem__(self, key: Union[int, BaseTracer]):
|
||||
# this branch is used during tracing and the regular flow is used during evaluation
|
||||
if isinstance(key, BaseTracer):
|
||||
traced_computation = UnivariateFunction(
|
||||
input_base_value=key.output,
|
||||
arbitrary_func=MultiLookupTable._checked_indexing,
|
||||
output_dtype=self.output_dtype,
|
||||
op_kwargs={
|
||||
"input_shape": deepcopy(self.input_shape),
|
||||
"tables": deepcopy(self.tables),
|
||||
},
|
||||
op_name="MultiTLU",
|
||||
)
|
||||
return key.__class__(
|
||||
inputs=[key],
|
||||
traced_computation=traced_computation,
|
||||
output_idx=0,
|
||||
)
|
||||
|
||||
# if not, it means table is indexed with a constant
|
||||
# thus, the result of the lookup is a constant
|
||||
# so, we can propagate it directly
|
||||
return MultiLookupTable._checked_indexing(key, self.input_shape, self.tables)
|
||||
|
||||
@staticmethod
|
||||
def _extract_shape_using_first_elements_only(array, shape):
|
||||
if not isinstance(array, list):
|
||||
# base case for recursion
|
||||
# the shape is already accumulated up to this point
|
||||
# so we just return
|
||||
return
|
||||
|
||||
if len(array) == 0:
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3, 2, 0])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [],
|
||||
# [table1, table2, table1],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
|
||||
raise ValueError("MultiLookupTable cannot have an empty array within it")
|
||||
|
||||
shape.append(len(array))
|
||||
MultiLookupTable._extract_shape_using_first_elements_only(array[0], shape)
|
||||
|
||||
@staticmethod
|
||||
def _check_shape_and_record_luts(array, dimension, shape, table_sizes, table_output_dtypes):
|
||||
if dimension == len(shape):
|
||||
if not isinstance(array, LookupTable):
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3, 2, 0])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [table1, table2, 4],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
raise ValueError(
|
||||
f"MultiLookupTable should have been made out of LookupTables "
|
||||
f"but it had an object of type {array.__class__.__name__} within it"
|
||||
)
|
||||
|
||||
table_sizes.append(len(array.table))
|
||||
table_output_dtypes.append(array.output_dtype)
|
||||
return
|
||||
|
||||
if not isinstance(array, list) or len(array) != shape[dimension]:
|
||||
# this branch is for such a case:
|
||||
#
|
||||
# table1 = hnp.LookupTable([1, 3, 2, 0])
|
||||
# table2 = hnp.LookupTable([0, 2, 3, 1])
|
||||
#
|
||||
# multitable = hnp.MultiLookupTable(
|
||||
# [
|
||||
# [table1, table2],
|
||||
# [table2, table1, table2],
|
||||
# ],
|
||||
# )
|
||||
raise ValueError(
|
||||
f"MultiLookupTable should have the shape {shape} but it does not "
|
||||
f"(an array on dimension {dimension} has the size {len(array)} "
|
||||
f"but its size should have been {shape[dimension]} "
|
||||
f"as the expected shape is {shape})"
|
||||
)
|
||||
|
||||
for item in array:
|
||||
MultiLookupTable._check_shape_and_record_luts(
|
||||
item,
|
||||
dimension + 1,
|
||||
shape,
|
||||
table_sizes,
|
||||
table_output_dtypes,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _checked_indexing(x, input_shape, tables):
|
||||
try:
|
||||
result = []
|
||||
for indices in itertools.product(*[range(dimension) for dimension in input_shape]):
|
||||
which_table_to_use = tables
|
||||
what_value_to_use = x
|
||||
where_to_append = result
|
||||
|
||||
for index in indices[:-1]:
|
||||
which_table_to_use = tables[index]
|
||||
what_value_to_use = x[index]
|
||||
|
||||
if len(where_to_append) == index:
|
||||
where_to_append.append([])
|
||||
where_to_append = result[index]
|
||||
|
||||
which_table_to_use = which_table_to_use[indices[-1]]
|
||||
what_value_to_use = what_value_to_use[indices[-1]]
|
||||
where_to_append.append(which_table_to_use[what_value_to_use])
|
||||
except Exception as error:
|
||||
raise ValueError(
|
||||
f"Multiple Lookup Table of shape {input_shape} cannot be looked up with {x} "
|
||||
f"(you should check your inputset)",
|
||||
) from error
|
||||
|
||||
return result
|
||||
@@ -66,6 +66,8 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
|
||||
elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions
|
||||
assert_true(len(inputs) == 1)
|
||||
if node.op_name == "MultiTLU":
|
||||
return "direct multi table lookup is not supported for the time being"
|
||||
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
|
||||
return "only unsigned integer scalar lookup tables are supported"
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger
|
||||
from ..common.debugging import draw_graph, get_printable_graph
|
||||
from ..common.extensions.multi_table import MultiLookupTable
|
||||
from ..common.extensions.table import LookupTable
|
||||
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph
|
||||
|
||||
@@ -267,12 +267,20 @@ def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
|
||||
"""
|
||||
|
||||
assert_true(
|
||||
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
|
||||
isinstance(
|
||||
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
|
||||
),
|
||||
f"Unsupported constant data of type {type(constant_data)}",
|
||||
)
|
||||
|
||||
if isinstance(constant_data, list):
|
||||
# this is required because some operations return python lists from their evaluate function
|
||||
# an example of such operation is evaluation of multi tlu during bound measurements
|
||||
constant_data = numpy.array(constant_data)
|
||||
|
||||
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
|
||||
if isinstance(constant_data, numpy.ndarray):
|
||||
return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype)
|
||||
return constant_data.dtype.type
|
||||
|
||||
return get_constructor_for_python_constant_data(constant_data)
|
||||
|
||||
118
tests/common/extensions/test_multi_table.py
Normal file
118
tests/common/extensions/test_multi_table.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Test file for direct multi table lookups"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
|
||||
table_2b_to_2b = LookupTable([1, 2, 0, 3])
|
||||
table_2b_to_1b = LookupTable([1, 0, 0, 1])
|
||||
table_2b_to_3b = LookupTable([5, 2, 7, 0])
|
||||
|
||||
table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2])
|
||||
table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0])
|
||||
table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2])
|
||||
|
||||
tables_2b = [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b]
|
||||
tables_3b = [table_3b_to_1b, table_3b_to_2b, table_3b_to_3b]
|
||||
|
||||
|
||||
def test_multi_lookup_table_creation_and_indexing():
|
||||
"""Test function for creating and indexing multi lookup tables"""
|
||||
tables = [
|
||||
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
|
||||
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
|
||||
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
|
||||
]
|
||||
multitable = MultiLookupTable(tables)
|
||||
|
||||
assert multitable.input_shape == (3, 2)
|
||||
|
||||
assert isinstance(multitable.output_dtype, Integer)
|
||||
assert multitable.output_dtype.bit_width <= 3
|
||||
|
||||
index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist()
|
||||
result = multitable[index]
|
||||
|
||||
for i in range(3):
|
||||
for j in range(2):
|
||||
assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tables,match",
|
||||
[
|
||||
pytest.param(
|
||||
[
|
||||
[],
|
||||
[table_2b_to_2b, table_2b_to_3b],
|
||||
],
|
||||
"MultiLookupTable cannot have an empty array within it",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_1b, 42.0],
|
||||
[table_2b_to_2b, table_2b_to_3b],
|
||||
],
|
||||
"MultiLookupTable should have been made out of LookupTables "
|
||||
"but it had an object of type float within it",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_2b],
|
||||
[table_2b_to_2b, table_2b_to_3b],
|
||||
[table_2b_to_2b, table_2b_to_1b],
|
||||
],
|
||||
"MultiLookupTable should have the shape (3, 1) but it does not "
|
||||
"(an array on dimension 1 has the size 2 but its size should have been 1 "
|
||||
"as the expected shape is (3, 1))",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_2b, table_3b_to_3b],
|
||||
[table_2b_to_2b, table_3b_to_1b],
|
||||
],
|
||||
"LookupTables within a MultiLookupTable should have the same size but they do not "
|
||||
"(there was a table with the size of 4 and another with the size of 8)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_lookup_table_creation_failure(tables, match):
|
||||
"""Test function for failing to create multi lookup tables"""
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
MultiLookupTable(tables)
|
||||
|
||||
assert str(excinfo.value) == match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tables,index,match",
|
||||
[
|
||||
pytest.param(
|
||||
[
|
||||
[table_2b_to_2b, table_2b_to_1b, table_2b_to_3b],
|
||||
[table_2b_to_1b, table_2b_to_2b, table_2b_to_3b],
|
||||
],
|
||||
[
|
||||
[1, 2],
|
||||
[3, 0],
|
||||
],
|
||||
"Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] "
|
||||
"(you should check your inputset)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_multi_lookup_table_indexing_failure(tables, index, match):
|
||||
"""Test function for failing to index multi lookup tables"""
|
||||
|
||||
table = MultiLookupTable(tables)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
table.__getitem__(index)
|
||||
|
||||
assert str(excinfo.value) == match
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
@@ -123,6 +124,19 @@ def random_lut_7b(x):
|
||||
return table[x]
|
||||
|
||||
|
||||
def multi_lut(x):
|
||||
"""2-bit multi table lookup"""
|
||||
|
||||
table = MultiLookupTable(
|
||||
[
|
||||
[LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])],
|
||||
[LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])],
|
||||
[LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])],
|
||||
]
|
||||
)
|
||||
return table[x]
|
||||
|
||||
|
||||
def small_fused_table(x):
|
||||
"""Test with a small fused table"""
|
||||
return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32)
|
||||
@@ -924,6 +938,21 @@ return(%7)
|
||||
"return(%2)\n"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
multi_lut,
|
||||
{"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))},
|
||||
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)],
|
||||
(
|
||||
"""
|
||||
function you are trying to compile isn't supported for MLIR lowering
|
||||
|
||||
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
|
||||
%1 = MultiTLU(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
|
||||
return(%1)
|
||||
""".lstrip() # noqa: E501
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
Reference in New Issue
Block a user