Files
concrete/frontends/concrete-python/tests/execution/test_direct_table_lookup.py

324 lines
8.6 KiB
Python

"""
Tests of execution of direct table lookup operation.
"""
import numpy as np
import pytest
from concrete import fhe
def identity_table_lookup_generator(n):
"""
Get identity table lookup function.
"""
return lambda x: fhe.LookupTable(range(2**n))[x]
def random_table_lookup_1b(x):
"""
Lookup on a random table with 1-bit input.
"""
# fmt: off
table = fhe.LookupTable([10, 12])
# fmt: on
return table[x]
def random_table_lookup_2b(x):
"""
Lookup on a random table with 2-bit input.
"""
# fmt: off
table = fhe.LookupTable([3, 8, 22, 127])
# fmt: on
return table[x]
def random_table_lookup_3b(x):
"""
Lookup on a random table with 3-bit input.
"""
# fmt: off
table = fhe.LookupTable([30, 52, 125, 23, 17, 12, 90, 4])
# fmt: on
return table[x]
def random_table_lookup_4b(x):
"""
Lookup on a random table with 4-bit input.
"""
# fmt: off
table = fhe.LookupTable([30, 52, 125, 23, 17, 12, 90, 4, 21, 51, 22, 15, 53, 100, 75, 90])
# fmt: on
return table[x]
def random_table_lookup_5b(x):
"""
Lookup on a random table with 5-bit input.
"""
# fmt: off
table = fhe.LookupTable(
[
1, 5, 2, 3, 10, 2, 4, 8, 1, 12, 15, 12, 10, 1, 0, 2,
4, 3, 8, 7, 10, 11, 6, 13, 9, 0, 2, 1, 15, 11, 12, 5
]
)
# fmt: on
return table[x]
def random_table_lookup_6b(x):
"""
Lookup on a random table with 6-bit input.
"""
# fmt: off
table = fhe.LookupTable(
[
95, 74, 11, 83, 24, 116, 28, 75, 26, 85, 114, 121, 91, 123, 78, 69,
72, 115, 67, 5, 39, 11, 120, 88, 56, 43, 74, 16, 72, 85, 103, 92,
44, 115, 50, 56, 107, 77, 25, 71, 52, 45, 80, 35, 69, 8, 40, 87,
26, 85, 84, 53, 73, 95, 86, 22, 16, 45, 59, 112, 53, 113, 98, 116
]
)
# fmt: on
return table[x]
def random_table_lookup_7b(x):
"""
Lookup on a random table with 7-bit input.
"""
# fmt: off
table = fhe.LookupTable(
[
13, 58, 38, 58, 15, 15, 77, 86, 80, 94, 108, 27, 126, 60, 65, 95,
50, 79, 22, 97, 38, 60, 25, 48, 73, 112, 27, 45, 88, 20, 67, 17,
16, 6, 71, 60, 77, 43, 93, 40, 41, 31, 99, 122, 120, 40, 94, 13,
111, 44, 96, 62, 108, 91, 34, 90, 103, 58, 3, 103, 19, 69, 55, 108,
0, 111, 113, 0, 0, 73, 22, 52, 81, 2, 88, 76, 36, 121, 97, 121,
123, 79, 82, 120, 12, 65, 54, 101, 90, 52, 84, 106, 23, 15, 110, 79,
85, 101, 30, 61, 104, 35, 81, 30, 98, 44, 111, 32, 68, 18, 45, 123,
84, 80, 68, 27, 31, 38, 126, 61, 51, 7, 49, 37, 63, 114, 22, 18,
]
)
# fmt: on
return table[x]
def negative_identity_table_lookup_generator(n):
"""
Get negative identity table lookup function.
"""
return lambda x: fhe.LookupTable([-i for i in range(2**n)])[x]
@pytest.mark.parametrize(
"bits,function",
[
pytest.param(1, identity_table_lookup_generator(1)),
pytest.param(2, identity_table_lookup_generator(2)),
pytest.param(3, identity_table_lookup_generator(3)),
pytest.param(4, identity_table_lookup_generator(4)),
pytest.param(5, identity_table_lookup_generator(5)),
pytest.param(6, identity_table_lookup_generator(6)),
pytest.param(7, identity_table_lookup_generator(7)),
pytest.param(1, random_table_lookup_1b),
pytest.param(2, random_table_lookup_2b),
pytest.param(3, random_table_lookup_3b),
pytest.param(4, random_table_lookup_4b),
pytest.param(5, random_table_lookup_5b),
pytest.param(6, random_table_lookup_6b),
pytest.param(7, random_table_lookup_7b),
pytest.param(1, negative_identity_table_lookup_generator(1)),
pytest.param(2, negative_identity_table_lookup_generator(2)),
pytest.param(3, negative_identity_table_lookup_generator(3)),
pytest.param(4, negative_identity_table_lookup_generator(4)),
pytest.param(5, negative_identity_table_lookup_generator(5)),
pytest.param(6, negative_identity_table_lookup_generator(6)),
],
)
def test_direct_table_lookup(bits, function, helpers):
"""
Test direct table lookup.
"""
configuration = helpers.configuration()
# scalar
# ------
compiler = fhe.Compiler(function, {"x": "encrypted"})
inputset = range(2**bits)
circuit = compiler.compile(inputset, configuration)
sample = int(np.random.randint(0, 2**bits))
helpers.check_execution(circuit, function, sample, retries=3)
# tensor
# ------
compiler = fhe.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**bits, size=(3, 2)) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2**bits, size=(3, 2))
helpers.check_execution(circuit, function, sample, retries=3)
# negative scalar
# ---------------
compiler = fhe.Compiler(function, {"x": "encrypted"})
inputset = range(-(2 ** (bits - 1)), 2 ** (bits - 1))
circuit = compiler.compile(inputset, configuration)
sample = int(np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1)))
helpers.check_execution(circuit, function, sample, retries=3)
# negative tensor
# ---------------
compiler = fhe.Compiler(function, {"x": "encrypted"})
inputset = [
np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2)) for _ in range(100)
]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(-(2 ** (bits - 1)), 2 ** (bits - 1), size=(3, 2))
helpers.check_execution(circuit, function, sample, retries=3)
def test_direct_multi_table_lookup(helpers):
"""
Test direct multi table lookup.
"""
configuration = helpers.configuration()
square = fhe.LookupTable([i * i for i in range(4)])
cube = fhe.LookupTable([i * i * i for i in range(4)])
table = fhe.LookupTable(
[
[square, cube],
[cube, square],
[square, cube],
]
)
def function(x):
return table[x]
compiler = fhe.Compiler(function, {"x": "encrypted"})
inputset = [np.random.randint(0, 2**2, size=(3, 2)) for _ in range(100)]
circuit = compiler.compile(inputset, configuration)
sample = np.random.randint(0, 2**2, size=(3, 2))
helpers.check_execution(circuit, function, sample, retries=3)
def test_bad_direct_table_lookup(helpers):
"""
Test direct table lookup with bad parameters.
"""
configuration = helpers.configuration()
# empty table
# -----------
with pytest.raises(ValueError) as excinfo:
fhe.LookupTable([])
assert str(excinfo.value) == "LookupTable cannot be constructed with []"
# invalid table
# -------------
with pytest.raises(ValueError) as excinfo:
fhe.LookupTable([[0, 1], [2, 3]])
assert str(excinfo.value) == "LookupTable cannot be constructed with [[0, 1], [2, 3]]"
# invalid multi table
# -------------------
with pytest.raises(ValueError) as excinfo:
fhe.LookupTable(["abc", 3.2])
assert str(excinfo.value) == "LookupTable cannot be constructed with ['abc', 3.2]"
# simulation with float value
# ---------------------------
with pytest.raises(ValueError) as excinfo:
random_table_lookup_3b(1.1)
assert str(excinfo.value) == "LookupTable cannot be looked up with 1.1"
# simulation with invalid shape
# -----------------------------
square = fhe.LookupTable([i * i for i in range(4)])
cube = fhe.LookupTable([i * i * i for i in range(4)])
table = fhe.LookupTable(
[
[square, cube],
[cube, square],
[square, cube],
]
)
with pytest.raises(ValueError) as excinfo:
_ = table[np.array([1, 2])]
assert str(excinfo.value) == "LookupTable of shape (3, 2) cannot be looked up with [1 2]"
# compilation with float value
# ----------------------------
compiler = fhe.Compiler(random_table_lookup_3b, {"x": "encrypted"})
inputset = [1.5]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset, configuration)
assert str(excinfo.value) == "LookupTable cannot be looked up with EncryptedScalar<float64>"
# compilation with invalid shape
# ------------------------------
compiler = fhe.Compiler(lambda x: table[x], {"x": "encrypted"})
inputset = [10, 5, 6, 2]
with pytest.raises(ValueError) as excinfo:
compiler.compile(inputset, configuration)
assert str(excinfo.value) == (
"LookupTable of shape (3, 2) cannot be looked up with EncryptedScalar<uint4>"
)