Files
concrete/tests/common/extensions/test_multi_table.py

119 lines
3.7 KiB
Python

"""Test file for direct multi table lookups"""
import random
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
table_2b_to_2b = LookupTable([1, 2, 0, 3])
table_2b_to_1b = LookupTable([1, 0, 0, 1])
table_2b_to_3b = LookupTable([5, 2, 7, 0])
table_3b_to_2b = LookupTable([1, 2, 0, 3, 0, 3, 1, 2])
table_3b_to_1b = LookupTable([1, 0, 0, 1, 1, 1, 1, 0])
table_3b_to_3b = LookupTable([5, 2, 7, 0, 4, 1, 6, 2])
tables_2b = [table_2b_to_1b, table_2b_to_2b, table_2b_to_3b]
tables_3b = [table_3b_to_1b, table_3b_to_2b, table_3b_to_3b]
def test_multi_lookup_table_creation_and_indexing():
"""Test function for creating and indexing multi lookup tables"""
tables = [
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
[tables_2b[random.randint(0, 2)], tables_2b[random.randint(0, 2)]],
]
multitable = MultiLookupTable(tables)
assert multitable.input_shape == (3, 2)
assert isinstance(multitable.output_dtype, Integer)
assert multitable.output_dtype.bit_width <= 3
index = numpy.random.randint(0, 2 ** 2, size=multitable.input_shape).tolist()
result = multitable[index]
for i in range(3):
for j in range(2):
assert result[i][j] == multitable.tables[i][j][index[i][j]], f"i={i}, j={j}"
@pytest.mark.parametrize(
"tables,match",
[
pytest.param(
[
[],
[table_2b_to_2b, table_2b_to_3b],
],
"MultiLookupTable cannot have an empty array within it",
),
pytest.param(
[
[table_2b_to_1b, 42.0],
[table_2b_to_2b, table_2b_to_3b],
],
"MultiLookupTable should have been made out of LookupTables "
"but it had an object of type float within it",
),
pytest.param(
[
[table_2b_to_2b],
[table_2b_to_2b, table_2b_to_3b],
[table_2b_to_2b, table_2b_to_1b],
],
"MultiLookupTable should have the shape (3, 1) but it does not "
"(an array on dimension 1 has the size 2 but its size should have been 1 "
"as the expected shape is (3, 1))",
),
pytest.param(
[
[table_2b_to_2b, table_3b_to_3b],
[table_2b_to_2b, table_3b_to_1b],
],
"LookupTables within a MultiLookupTable should have the same size but they do not "
"(there was a table with the size of 4 and another with the size of 8)",
),
],
)
def test_multi_lookup_table_creation_failure(tables, match):
"""Test function for failing to create multi lookup tables"""
with pytest.raises(ValueError) as excinfo:
MultiLookupTable(tables)
assert str(excinfo.value) == match
@pytest.mark.parametrize(
"tables,index,match",
[
pytest.param(
[
[table_2b_to_2b, table_2b_to_1b, table_2b_to_3b],
[table_2b_to_1b, table_2b_to_2b, table_2b_to_3b],
],
[
[1, 2],
[3, 0],
],
"Multiple Lookup Table of shape (2, 3) cannot be looked up with [[1, 2], [3, 0]] "
"(you should check your inputset)",
),
],
)
def test_multi_lookup_table_indexing_failure(tables, index, match):
"""Test function for failing to index multi lookup tables"""
table = MultiLookupTable(tables)
with pytest.raises(ValueError) as excinfo:
table.__getitem__(index)
assert str(excinfo.value) == match