feat(extensions): add and test support for negative direct table lookups

This commit is contained in:
Umut
2021-11-30 16:10:59 +03:00
parent 8b0a793cda
commit 1b5785e058
4 changed files with 78 additions and 29 deletions

View File

@@ -50,6 +50,42 @@ def identity_lut_generator(n):
return lambda x: LookupTable(list(range(2 ** n)))[x]
def negative_identity_smaller_lut_generator(n):
"""Test negative lookup table"""
table = LookupTable(range(2 ** (n - 1)))
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def negative_identity_lut_generator(n):
"""Test negative lookup table (bigger than bit-width)"""
table = LookupTable(range(2 ** n))
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def negative_identity_bigger_lut_generator(n):
"""Test negative lookup table (bigger than bit-width)"""
table = LookupTable(range(2 ** (n + 1)))
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def weird_lut(n):
"""A weird lookup table to test an edge case related to negative indexing"""
table = LookupTable([0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 6, 7])
offset = 2 ** (n - 1)
return (lambda x: table[x + (-offset)]), table
def random_lut_1b(x):
"""1-bit random table lookup"""
@@ -1375,6 +1411,45 @@ def test_compile_and_run_lut_correctness(
check_is_good_execution(compiler_engine, function, args)
@pytest.mark.parametrize(
"function,table,bit_width",
[
pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)")
for n in range(1, 8)
]
+ [
pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)")
for n in range(1, 8)
]
+ [
pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)")
for n in range(1, 7)
]
+ [
pytest.param(*weird_lut(3), 3, id="weird"),
],
)
def test_compile_and_run_negative_lut_correctness(
function,
table,
bit_width,
default_compilation_configuration,
):
"""Test correctness when running a compiled function with LUT using negative values"""
circuit = compile_numpy_function(
function,
{"x": EncryptedScalar(UnsignedInteger(bit_width))},
range(2 ** bit_width),
default_compilation_configuration,
)
offset = 2 ** (bit_width - 1)
for value in range(-offset, offset):
assert table[value] == function(value + offset)
check_is_good_execution(circuit, function, [value + offset])
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
"""Test correctness of results when running a compiled function with Multi LUT"""