class MultiLookupTable:
    """Class representing a multi lookup table.

    Multi table lookup is needed when you want to perform a lookup on a tensor,
    but you want each element to be used with a different lookup table.

    Here is an example:

    You have x which is of shape (2, 3),
    you want the first row to be indexed with `table1 = LookupTable([2, 3, 1, 0])`
    and the second row to be indexed with `table2 = LookupTable([0, 1, 3, 2])`

    You can create such a multi lookup table

        multitable = MultiLookupTable(
            [
                [table1, table1, table1],
                [table2, table2, table2],
            ],
        )

    (notice the shape of multitable matches with the shape of x)

    and use multitable[x] to get the following result

        assert multitable[x] == [
            [table1[x[0, 0]], table1[x[0, 1]], table1[x[0, 2]]],
            [table2[x[1, 0]], table2[x[1, 1]], table2[x[1, 2]]],
        ]
    """

    # underlying lookup tables (arbitrarily nested lists of LookupTable)
    tables: List

    # shape of the input of the lookup
    input_shape: Tuple[int, ...]

    # type of the result of the lookup
    output_dtype: "BaseDataType"

    def __init__(self, tables: List):
        """Validate the nested tables and record their shape and output type.

        Args:
            tables: nested lists of LookupTable objects; the nesting defines
                the shape of the inputs that can be looked up.

        Raises:
            ValueError: if a nested list is empty, if the nesting is not
                rectangular, if a leaf is not a LookupTable, or if the
                LookupTables do not all have the same size.
        """
        input_shape_list: List[int] = []
        MultiLookupTable._extract_shape_using_first_elements_only(tables, input_shape_list)
        input_shape: Tuple[int, ...] = tuple(input_shape_list)

        table_sizes: List[int] = []
        table_output_dtypes: List[BaseDataType] = []
        MultiLookupTable._check_shape_and_record_luts(
            tables,
            0,
            input_shape,
            table_sizes,
            table_output_dtypes,
        )

        for previous_size, current_size in zip(table_sizes, table_sizes[1:]):
            if previous_size != current_size:
                # this branch is for such a case:
                #
                # table1 = hnp.LookupTable([1, 3])
                # table2 = hnp.LookupTable([0, 2, 3, 1])
                #
                # multitable = hnp.MultiLookupTable(
                #     [
                #         [table1, table2, table1],
                #         [table2, table1, table2],
                #     ],
                # )
                raise ValueError(
                    f"LookupTables within a MultiLookupTable "
                    f"should have the same size but they do not "
                    f"(there was a table with the size of {previous_size} "
                    f"and another with the size of {current_size})"
                )

        # the resulting type must be able to hold the output of every table
        output_dtype = table_output_dtypes[0]
        for table_output_dtype in table_output_dtypes:
            output_dtype = find_type_to_hold_both_lossy(output_dtype, table_output_dtype)

        self.tables = tables
        self.input_shape = input_shape
        self.output_dtype = output_dtype

    def __getitem__(self, key: "Union[int, BaseTracer]"):
        """Look up `key` element-wise, or record the lookup when tracing.

        Args:
            key: either a tracer (during tracing) or a constant value whose
                nesting matches `input_shape` (during evaluation).

        Returns:
            A new tracer of the same class as `key` wrapping a "MultiTLU"
            UnivariateFunction when `key` is a tracer, otherwise the nested
            list holding the result of each individual lookup.

        Raises:
            ValueError: if a constant `key` cannot be looked up.
        """
        # this branch is used during tracing and the regular flow is used during evaluation
        if isinstance(key, BaseTracer):
            traced_computation = UnivariateFunction(
                input_base_value=key.output,
                arbitrary_func=MultiLookupTable._checked_indexing,
                output_dtype=self.output_dtype,
                op_kwargs={
                    "input_shape": deepcopy(self.input_shape),
                    "tables": deepcopy(self.tables),
                },
                op_name="MultiTLU",
            )
            return key.__class__(
                inputs=[key],
                traced_computation=traced_computation,
                output_idx=0,
            )

        # if not, it means table is indexed with a constant
        # thus, the result of the lookup is a constant
        # so, we can propagate it directly
        return MultiLookupTable._checked_indexing(key, self.input_shape, self.tables)

    @staticmethod
    def _extract_shape_using_first_elements_only(array, shape):
        """Append to `shape` the size of each nesting level of `array`.

        Only the first element of each level is followed, so the result is
        the expected shape; `_check_shape_and_record_luts` verifies that the
        remaining elements actually conform to it.

        Raises:
            ValueError: if an empty list is encountered on the way down.
        """
        if not isinstance(array, list):
            # base case for recursion
            # the shape is already accumulated up to this point
            # so we just return
            return

        if len(array) == 0:
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [],
            #         [table1, table2, table1],
            #         [table2, table1, table2],
            #     ],
            # )

            raise ValueError("MultiLookupTable cannot have an empty array within it")

        shape.append(len(array))
        MultiLookupTable._extract_shape_using_first_elements_only(array[0], shape)

    @staticmethod
    def _check_shape_and_record_luts(array, dimension, shape, table_sizes, table_output_dtypes):
        """Recursively verify that `array` matches `shape` and collect table info.

        Args:
            array: the (sub)structure being checked.
            dimension: nesting depth of `array` within the whole structure.
            shape: the expected shape of the whole structure.
            table_sizes: output list; the size of each table is appended to it.
            table_output_dtypes: output list; the output dtype of each table
                is appended to it.

        Raises:
            ValueError: if a leaf is not a LookupTable, or if a level is not
                a list of the expected size.
        """
        if dimension == len(shape):
            if not isinstance(array, LookupTable):
                # this branch is for such a case:
                #
                # table1 = hnp.LookupTable([1, 3, 2, 0])
                # table2 = hnp.LookupTable([0, 2, 3, 1])
                #
                # multitable = hnp.MultiLookupTable(
                #     [
                #         [table1, table2, 4],
                #         [table2, table1, table2],
                #     ],
                # )
                raise ValueError(
                    f"MultiLookupTable should have been made out of LookupTables "
                    f"but it had an object of type {array.__class__.__name__} within it"
                )

            table_sizes.append(len(array.table))
            table_output_dtypes.append(array.output_dtype)
            return

        if not isinstance(array, list):
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [table1, table2, table1],
            #         4,
            #     ],
            # )
            #
            # it is handled separately from the size mismatch below because
            # the object found here may not support len() at all
            raise ValueError(
                f"MultiLookupTable should have the shape {shape} but it does not "
                f"(an object of type {array.__class__.__name__} was found on dimension "
                f"{dimension} where an array of size {shape[dimension]} was expected)"
            )

        if len(array) != shape[dimension]:
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [table1, table2],
            #         [table2, table1, table2],
            #     ],
            # )
            raise ValueError(
                f"MultiLookupTable should have the shape {shape} but it does not "
                f"(an array on dimension {dimension} has the size {len(array)} "
                f"but its size should have been {shape[dimension]} "
                f"as the expected shape is {shape})"
            )

        for item in array:
            MultiLookupTable._check_shape_and_record_luts(
                item,
                dimension + 1,
                shape,
                table_sizes,
                table_output_dtypes,
            )

    @staticmethod
    def _checked_indexing(x, input_shape, tables):
        """Look up every element of `x` in the table at the same position.

        Args:
            x: value to look up; it is indexed one dimension at a time and
                must be nested according to `input_shape`.
            input_shape: shape shared by `x` and `tables`.
            tables: nested lists of tables, each indexable with an element of `x`.

        Returns:
            Nested lists of shape `input_shape` holding the result of each lookup.

        Raises:
            ValueError: if anything goes wrong during the lookup
                (e.g., `x` does not have the shape `input_shape`).
        """
        try:
            result = []
            # itertools.product yields the indices in row-major order,
            # so every nested list of the result is created right before
            # its first element is appended and is filled in order
            for indices in itertools.product(*[range(dimension) for dimension in input_shape]):
                which_table_to_use = tables
                what_value_to_use = x
                where_to_append = result

                # descend one level at a time, always from the current level
                # (restarting from the root here breaks inputs with 3+ dimensions)
                for index in indices[:-1]:
                    which_table_to_use = which_table_to_use[index]
                    what_value_to_use = what_value_to_use[index]

                    if len(where_to_append) == index:
                        where_to_append.append([])
                    where_to_append = where_to_append[index]

                which_table_to_use = which_table_to_use[indices[-1]]
                what_value_to_use = what_value_to_use[indices[-1]]
                where_to_append.append(which_table_to_use[what_value_to_use])
        except Exception as error:
            # any failure here means the input is incompatible with the tables,
            # so it is reported uniformly and the original error is chained
            raise ValueError(
                f"Multiple Lookup Table of shape {input_shape} cannot be looked up with {x} "
                f"(you should check your inputset)",
            ) from error

        return result
EncryptedScalar, EncryptedTensor, TensorValue from .compile import compile_numpy_function, compile_numpy_function_into_op_graph diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 9ff9db7af..ecf51d31a 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -267,12 +267,20 @@ def get_constructor_for_numpy_or_python_constant_data(constant_data: Any): """ assert_true( - isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)), + isinstance( + constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES) + ), f"Unsupported constant data of type {type(constant_data)}", ) + if isinstance(constant_data, list): + # this is required because some operations return python lists from their evaluate function + # an example of such operation is evaluation of multi tlu during bound measurements + constant_data = numpy.array(constant_data) + if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)): if isinstance(constant_data, numpy.ndarray): return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype) return constant_data.dtype.type + return get_constructor_for_python_constant_data(constant_data) diff --git a/tests/common/extensions/test_multi_table.py b/tests/common/extensions/test_multi_table.py new file mode 100644 index 000000000..6590e9699 --- /dev/null +++ b/tests/common/extensions/test_multi_table.py @@ -0,0 +1,118 @@ +"""Test file for direct multi table lookups""" + +import random + +import numpy +import pytest + +from concrete.common.data_types.integers import Integer +from concrete.common.extensions.multi_table import MultiLookupTable +from concrete.common.extensions.table import LookupTable + +table_2b_to_2b = LookupTable([1, 2, 0, 3]) +table_2b_to_1b = LookupTable([1, 0, 0, 1]) +table_2b_to_3b = LookupTable([5, 2, 7, 0]) + +table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2]) 
+table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0]) +table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2]) + +tables_2b = [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b] +tables_3b = [table_3b_to_1b, table_3b_to_2b, table_3b_to_3b] + + +def test_multi_lookup_table_creation_and_indexing(): + """Test function for creating and indexing multi lookup tables""" + tables = [ + [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]], + [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]], + [tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]], + ] + multitable = MultiLookupTable(tables) + + assert multitable.input_shape == (3, 2) + + assert isinstance(multitable.output_dtype, Integer) + assert multitable.output_dtype.bit_width <= 3 + + index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist() + result = multitable[index] + + for i in range(3): + for j in range(2): + assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}" + + +@pytest.mark.parametrize( + "tables,match", + [ + pytest.param( + [ + [], + [table_2b_to_2b, table_2b_to_3b], + ], + "MultiLookupTable cannot have an empty array within it", + ), + pytest.param( + [ + [table_2b_to_1b, 42.0], + [table_2b_to_2b, table_2b_to_3b], + ], + "MultiLookupTable should have been made out of LookupTables " + "but it had an object of type float within it", + ), + pytest.param( + [ + [table_2b_to_2b], + [table_2b_to_2b, table_2b_to_3b], + [table_2b_to_2b, table_2b_to_1b], + ], + "MultiLookupTable should have the shape (3, 1) but it does not " + "(an array on dimension 1 has the size 2 but its size should have been 1 " + "as the expected shape is (3, 1))", + ), + pytest.param( + [ + [table_2b_to_2b, table_3b_to_3b], + [table_2b_to_2b, table_3b_to_1b], + ], + "LookupTables within a MultiLookupTable should have the same size but they do not " + "(there was a table with the size of 4 and another with the size of 8)", + ), + ], +) +def 
test_multi_lookup_table_creation_failure(tables, match): + """Test function for failing to create multi lookup tables""" + + with pytest.raises(ValueError) as excinfo: + MultiLookupTable(tables) + + assert str(excinfo.value) == match + + +@pytest.mark.parametrize( + "tables,index,match", + [ + pytest.param( + [ + [table_2b_to_2b, table_2b_to_1b, table_2b_to_3b], + [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b], + ], + [ + [1, 2], + [3, 0], + ], + "Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] " + "(you should check your inputset)", + ), + ], +) +def test_multi_lookup_table_indexing_failure(tables, index, match): + """Test function for failing to index multi lookup tables""" + + table = MultiLookupTable(tables) + + with pytest.raises(ValueError) as excinfo: + table.__getitem__(index) + + assert str(excinfo.value) == match diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index a9169dbad..9fbaa3014 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -9,6 +9,7 @@ import pytest from concrete.common.compilation import CompilationConfiguration from concrete.common.data_types.integers import Integer, UnsignedInteger from concrete.common.debugging import draw_graph, get_printable_graph +from concrete.common.extensions.multi_table import MultiLookupTable from concrete.common.extensions.table import LookupTable from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor from concrete.numpy import tracing @@ -123,6 +124,19 @@ def random_lut_7b(x): return table[x] +def multi_lut(x): + """2-bit multi table lookup""" + + table = MultiLookupTable( + [ + [LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])], + [LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])], + [LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])], + ] + ) + return table[x] + + def small_fused_table(x): """Test with a small fused table""" return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32) @@ 
-924,6 +938,21 @@ return(%7) "return(%2)\n" ), ), + pytest.param( + multi_lut, + {"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))}, + [(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)], + ( + """ +function you are trying to compile isn't supported for MLIR lowering + +%0 = x # EncryptedTensor, shape=(3, 2)> +%1 = MultiTLU(%0) # EncryptedTensor, shape=(3, 2)> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being +return(%1) +""".lstrip() # noqa: E501 + ), + ), ], ) # pylint: enable=line-too-long