feat(extensions): create multi table lookup extension

This commit is contained in:
Umut
2021-10-27 17:54:33 +03:00
parent d4e5831a57
commit 39c16038c7
7 changed files with 382 additions and 2 deletions

View File

@@ -1,2 +1,2 @@
"""Extensions module to provide additional functionality to our users."""
from . import table
from . import multi_table, table

View File

@@ -0,0 +1,222 @@
"""This file contains a wrapper class for direct multi table lookups."""
import itertools
from copy import deepcopy
from typing import List, Tuple, Union
from ..data_types.base import BaseDataType
from ..data_types.dtypes_helpers import find_type_to_hold_both_lossy
from ..representation.intermediate import UnivariateFunction
from ..tracing.base_tracer import BaseTracer
from .table import LookupTable
class MultiLookupTable:
    """Class representing a multi lookup table."""

    # Multi table lookup is needed when you want to perform a lookup on a tensor,
    # but you want each element to be used with a different lookup table.
    #
    # Here is an example:
    #
    # You have x which is of shape (2, 3),
    # you want the first row to be indexed with `table1 = LookupTable([2, 3, 1, 0])`
    # and the second row to be indexed with `table2 = LookupTable([0, 1, 3, 2])`
    #
    # You can create such a multi lookup table
    #     multitable = MultiLookupTable(
    #         [
    #             [table1, table1, table1],
    #             [table2, table2, table2],
    #         ],
    #     )
    # (notice the shape of multitable matches with the shape of x)
    #
    # and use multitable[x] to get the following result
    #     assert multitable[x] == [
    #         [table1[x[0, 0]], table1[x[0, 1]], table1[x[0, 2]]],
    #         [table2[x[1, 0]], table2[x[1, 1]], table2[x[1, 2]]],
    #     ]

    # underlying lookup tables
    tables: List

    # shape of the input of the lookup
    input_shape: Tuple[int, ...]

    # type of the result of the lookup
    output_dtype: BaseDataType

    def __init__(self, tables: List):
        """Validate the nested list of lookup tables and record its shape and output type.

        Args:
            tables: (possibly nested) list of LookupTable objects

        Raises:
            ValueError: if an empty list is found, if the nesting is not rectangular,
                if a leaf is not a LookupTable, or if the tables do not all have the same size
        """

        # the candidate shape is read by following the first element of each nesting level
        input_shape_list: List[int] = []
        MultiLookupTable._extract_shape_using_first_elements_only(tables, input_shape_list)
        input_shape: Tuple[int, ...] = tuple(input_shape_list)

        # every other element is then checked against that candidate shape
        # and the size and the output dtype of each table are collected on the way
        table_sizes: List[int] = []
        table_output_dtypes: List[BaseDataType] = []
        MultiLookupTable._check_shape_and_record_luts(
            tables,
            0,
            input_shape,
            table_sizes,
            table_output_dtypes,
        )

        for previous_size, current_size in zip(table_sizes, table_sizes[1:]):
            if previous_size != current_size:
                # this branch is for such a case:
                #
                # table1 = hnp.LookupTable([1, 3])
                # table2 = hnp.LookupTable([0, 2, 3, 1])
                #
                # multitable = hnp.MultiLookupTable(
                #     [
                #         [table1, table2, table1],
                #         [table2, table1, table2],
                #     ],
                # )
                raise ValueError(
                    f"LookupTables within a MultiLookupTable "
                    f"should have the same size but they do not "
                    f"(there was a table with the size of {previous_size} "
                    f"and another with the size of {current_size})"
                )

        # the output dtype is the type found by folding find_type_to_hold_both_lossy
        # over the output dtypes of all the tables
        output_dtype = table_output_dtypes[0]
        for table_output_dtype in table_output_dtypes:
            output_dtype = find_type_to_hold_both_lossy(output_dtype, table_output_dtype)

        self.tables = tables
        self.input_shape = input_shape
        self.output_dtype = output_dtype

    def __getitem__(self, key: Union[int, BaseTracer]):
        """Perform the lookup, either symbolically (tracer key) or directly (constant key).

        Args:
            key: a tracer during tracing, or the values to look up otherwise

        Returns:
            a new tracer of the same class as key during tracing,
            or the nested list of looked up values otherwise
        """

        # this branch is used during tracing and the regular flow is used during evaluation
        if isinstance(key, BaseTracer):
            traced_computation = UnivariateFunction(
                input_base_value=key.output,
                arbitrary_func=MultiLookupTable._checked_indexing,
                output_dtype=self.output_dtype,
                op_kwargs={
                    "input_shape": deepcopy(self.input_shape),
                    "tables": deepcopy(self.tables),
                },
                op_name="MultiTLU",
            )
            return key.__class__(
                inputs=[key],
                traced_computation=traced_computation,
                output_idx=0,
            )

        # if not, it means table is indexed with a constant
        # thus, the result of the lookup is a constant
        # so, we can propagate it directly
        return MultiLookupTable._checked_indexing(key, self.input_shape, self.tables)

    @staticmethod
    def _extract_shape_using_first_elements_only(array, shape):
        """Append the length of each nesting level of array to shape, following first elements.

        Args:
            array: (possibly nested) list to inspect
            shape: list which is extended in place with the discovered dimensions

        Raises:
            ValueError: if an empty list is encountered on the way
        """

        if not isinstance(array, list):
            # base case for recursion
            # the shape is already accumulated up to this point
            # so we just return
            return

        if len(array) == 0:
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [],
            #         [table1, table2, table1],
            #         [table2, table1, table2],
            #     ],
            # )
            raise ValueError("MultiLookupTable cannot have an empty array within it")

        shape.append(len(array))
        MultiLookupTable._extract_shape_using_first_elements_only(array[0], shape)

    @staticmethod
    def _check_shape_and_record_luts(array, dimension, shape, table_sizes, table_output_dtypes):
        """Check that array matches shape from dimension onwards and record its tables.

        Args:
            array: the (sub)array which is being checked
            dimension: nesting depth of array within the whole multi lookup table
            shape: expected shape of the whole multi lookup table
            table_sizes: list which is extended in place with the size of each table
            table_output_dtypes: list which is extended in place with the output dtype
                of each table

        Raises:
            ValueError: if a leaf is not a LookupTable or if array does not match shape
        """

        if dimension == len(shape):
            if not isinstance(array, LookupTable):
                # this branch is for such a case:
                #
                # table1 = hnp.LookupTable([1, 3, 2, 0])
                # table2 = hnp.LookupTable([0, 2, 3, 1])
                #
                # multitable = hnp.MultiLookupTable(
                #     [
                #         [table1, table2, 4],
                #         [table2, table1, table2],
                #     ],
                # )
                raise ValueError(
                    f"MultiLookupTable should have been made out of LookupTables "
                    f"but it had an object of type {array.__class__.__name__} within it"
                )

            table_sizes.append(len(array.table))
            table_output_dtypes.append(array.output_dtype)
            return

        if not isinstance(array, list):
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [table1, table2],
            #         42.0,
            #     ],
            # )
            #
            # it is handled separately from the size mismatch below
            # because len() cannot be relied upon for an arbitrary object
            raise ValueError(
                f"MultiLookupTable should have the shape {shape} but it does not "
                f"(an object of type {array.__class__.__name__} "
                f"was found on dimension {dimension} "
                f"where an array of size {shape[dimension]} was expected)"
            )

        if len(array) != shape[dimension]:
            # this branch is for such a case:
            #
            # table1 = hnp.LookupTable([1, 3, 2, 0])
            # table2 = hnp.LookupTable([0, 2, 3, 1])
            #
            # multitable = hnp.MultiLookupTable(
            #     [
            #         [table1, table2],
            #         [table2, table1, table2],
            #     ],
            # )
            raise ValueError(
                f"MultiLookupTable should have the shape {shape} but it does not "
                f"(an array on dimension {dimension} has the size {len(array)} "
                f"but its size should have been {shape[dimension]} "
                f"as the expected shape is {shape})"
            )

        for item in array:
            MultiLookupTable._check_shape_and_record_luts(
                item,
                dimension + 1,
                shape,
                table_sizes,
                table_output_dtypes,
            )

    @staticmethod
    def _checked_indexing(x, input_shape, tables):
        """Look up each element of x in the table at the same position within tables.

        Args:
            x: values to look up, indexable like a nested list of shape input_shape
            input_shape: shape of the multi lookup table
            tables: nested list of tables of shape input_shape

        Returns:
            nested list of shape input_shape containing the looked up values

        Raises:
            ValueError: if anything goes wrong during the lookup
                (the original error is chained as the cause)
        """

        try:
            result = []
            for indices in itertools.product(*[range(dimension) for dimension in input_shape]):
                which_table_to_use = tables
                what_value_to_use = x
                where_to_append = result

                # descend one nesting level per leading index
                # each step must continue from the level reached so far (not from the root)
                # otherwise shapes with more than two dimensions are walked incorrectly
                for index in indices[:-1]:
                    which_table_to_use = which_table_to_use[index]
                    what_value_to_use = what_value_to_use[index]
                    if len(where_to_append) == index:
                        # first visit of this position, so its sub-list is created
                        where_to_append.append([])
                    where_to_append = where_to_append[index]

                which_table_to_use = which_table_to_use[indices[-1]]
                what_value_to_use = what_value_to_use[indices[-1]]
                where_to_append.append(which_table_to_use[what_value_to_use])
        except Exception as error:
            # any failure here is reported uniformly as it means
            # the input is not compatible with the multi lookup table
            raise ValueError(
                f"Multiple Lookup Table of shape {input_shape} cannot be looked up with {x} "
                f"(you should check your inputset)",
            ) from error

        return result

View File

@@ -66,6 +66,8 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
elif isinstance(node, intermediate.UnivariateFunction): # constraints for univariate functions
assert_true(len(inputs) == 1)
if node.op_name == "MultiTLU":
return "direct multi table lookup is not supported for the time being"
if not value_is_scalar(inputs[0]) or not value_is_unsigned_integer(inputs[0]):
return "only unsigned integer scalar lookup tables are supported"

View File

@@ -3,6 +3,7 @@
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger
from ..common.debugging import draw_graph, get_printable_graph
from ..common.extensions.multi_table import MultiLookupTable
from ..common.extensions.table import LookupTable
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
from .compile import compile_numpy_function, compile_numpy_function_into_op_graph

View File

@@ -267,12 +267,20 @@ def get_constructor_for_numpy_or_python_constant_data(constant_data: Any):
"""
assert_true(
isinstance(constant_data, (int, float, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)),
isinstance(
constant_data, (int, float, list, numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)
),
f"Unsupported constant data of type {type(constant_data)}",
)
if isinstance(constant_data, list):
# this is required because some operations return python lists from their evaluate function
# an example of such operation is evaluation of multi tlu during bound measurements
constant_data = numpy.array(constant_data)
if isinstance(constant_data, (numpy.ndarray, SUPPORTED_NUMPY_DTYPES_CLASS_TYPES)):
if isinstance(constant_data, numpy.ndarray):
return lambda x: numpy.full(constant_data.shape, x, dtype=constant_data.dtype)
return constant_data.dtype.type
return get_constructor_for_python_constant_data(constant_data)

View File

@@ -0,0 +1,118 @@
"""Test file for direct multi table lookups"""
import random
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
# lookup tables with 4 entries (2-bit inputs), named after the bit-width of their outputs
table_2b_to_2b = LookupTable([1, 2, 0, 3])
table_2b_to_1b = LookupTable([1, 0, 0, 1])
table_2b_to_3b = LookupTable([5, 2, 7, 0])

# lookup tables with 8 entries (3-bit inputs), named after the bit-width of their outputs
table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2])
table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0])
table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2])

# tables grouped by input bit-width, ordered by increasing output bit-width
tables_2b = [
    table_2b_to_1b,
    table_2b_to_2b,
    table_2b_to_3b,
]
tables_3b = [
    table_3b_to_1b,
    table_3b_to_2b,
    table_3b_to_3b,
]
def test_multi_lookup_table_creation_and_indexing():
    """Check that a (3, 2) multi lookup table is created properly and indexed element-wise"""

    # each cell gets one of the 2-bit input tables picked at random
    tables = [[tables_2b[random.randint(0, 2)] for _ in range(2)] for _ in range(3)]
    multitable = MultiLookupTable(tables)

    assert multitable.input_shape == (3, 2)

    output_dtype = multitable.output_dtype
    assert isinstance(output_dtype, Integer)
    assert output_dtype.bit_width <= 3

    # random 2-bit values with the same shape as the multi lookup table
    index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist()
    result = multitable[index]

    # every cell of the result must come from the table at the very same position
    for i, j in numpy.ndindex(3, 2):
        assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}"
@pytest.mark.parametrize(
    "tables,match",
    [
        # an empty array within the tables
        (
            [
                [],
                [table_2b_to_2b, table_2b_to_3b],
            ],
            "MultiLookupTable cannot have an empty array within it",
        ),
        # an object which is not a lookup table within the tables
        (
            [
                [table_2b_to_1b, 42.0],
                [table_2b_to_2b, table_2b_to_3b],
            ],
            "MultiLookupTable should have been made out of LookupTables "
            "but it had an object of type float within it",
        ),
        # rows of different sizes (the expected shape comes from the first row)
        (
            [
                [table_2b_to_2b],
                [table_2b_to_2b, table_2b_to_3b],
                [table_2b_to_2b, table_2b_to_1b],
            ],
            "MultiLookupTable should have the shape (3, 1) but it does not "
            "(an array on dimension 1 has the size 2 but its size should have been 1 "
            "as the expected shape is (3, 1))",
        ),
        # lookup tables of different sizes
        (
            [
                [table_2b_to_2b, table_3b_to_3b],
                [table_2b_to_2b, table_3b_to_1b],
            ],
            "LookupTables within a MultiLookupTable should have the same size but they do not "
            "(there was a table with the size of 4 and another with the size of 8)",
        ),
    ],
)
def test_multi_lookup_table_creation_failure(tables, match):
    """Check that invalid tables are rejected with the expected error message"""

    with pytest.raises(ValueError) as raised:
        MultiLookupTable(tables)

    actual_message = str(raised.value)
    assert actual_message == match
@pytest.mark.parametrize(
    "tables,index,match",
    [
        # the index has the shape (2, 2) whereas the multi lookup table has the shape (2, 3)
        (
            [
                [table_2b_to_2b, table_2b_to_1b, table_2b_to_3b],
                [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b],
            ],
            [
                [1, 2],
                [3, 0],
            ],
            "Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] "
            "(you should check your inputset)",
        ),
    ],
)
def test_multi_lookup_table_indexing_failure(tables, index, match):
    """Check that an incompatible index is rejected with the expected error message"""

    multitable = MultiLookupTable(tables)

    with pytest.raises(ValueError) as raised:
        multitable[index]  # pylint: disable=pointless-statement

    actual_message = str(raised.value)
    assert actual_message == match

View File

@@ -9,6 +9,7 @@ import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.debugging import draw_graph, get_printable_graph
from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
@@ -123,6 +124,19 @@ def random_lut_7b(x):
return table[x]
def multi_lut(x):
    """Look x up in a (3, 2) multi lookup table made of 4-entry lookup tables"""

    # one list per row of the multi lookup table
    first_row = [LookupTable([1, 2, 1, 0]), LookupTable([2, 2, 1, 3])]
    second_row = [LookupTable([1, 0, 1, 0]), LookupTable([0, 2, 3, 3])]
    third_row = [LookupTable([0, 2, 3, 0]), LookupTable([2, 1, 2, 0])]

    multitable = MultiLookupTable([first_row, second_row, third_row])
    return multitable[x]
def small_fused_table(x):
    """Compute 10 * (cos(x + 1) + 1) and convert the result to uint32"""

    # operand order is kept as is (e.g., `10 * value` rather than `value * 10`)
    shifted_cosine = numpy.cos(x + 1) + 1
    scaled = 10 * shifted_cosine
    return scaled.astype(numpy.uint32)
@@ -924,6 +938,21 @@ return(%7)
"return(%2)\n"
),
),
pytest.param(
multi_lut,
{"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))},
[(numpy.random.randint(0, 2 ** 2, size=(3, 2)),) for _ in range(32)],
(
"""
function you are trying to compile isn't supported for MLIR lowering
%0 = x # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
%1 = MultiTLU(%0) # EncryptedTensor<Integer<unsigned, 2 bits>, shape=(3, 2)>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ direct multi table lookup is not supported for the time being
return(%1)
""".lstrip() # noqa: E501
),
),
],
)
# pylint: enable=line-too-long