mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(extensions): add and test support for negative direct table lookups
This commit is contained in:
@@ -50,6 +50,42 @@ def identity_lut_generator(n):
|
||||
return lambda x: LookupTable(list(range(2 ** n)))[x]
|
||||
|
||||
|
||||
def negative_identity_smaller_lut_generator(n):
|
||||
"""Test negative lookup table"""
|
||||
|
||||
table = LookupTable(range(2 ** (n - 1)))
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def negative_identity_lut_generator(n):
|
||||
"""Test negative lookup table (bigger than bit-width)"""
|
||||
|
||||
table = LookupTable(range(2 ** n))
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def negative_identity_bigger_lut_generator(n):
|
||||
"""Test negative lookup table (bigger than bit-width)"""
|
||||
|
||||
table = LookupTable(range(2 ** (n + 1)))
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def weird_lut(n):
|
||||
"""A weird lookup table to test an edge case related to negative indexing"""
|
||||
|
||||
table = LookupTable([0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 5, 6, 7])
|
||||
offset = 2 ** (n - 1)
|
||||
|
||||
return (lambda x: table[x + (-offset)]), table
|
||||
|
||||
|
||||
def random_lut_1b(x):
|
||||
"""1-bit random table lookup"""
|
||||
|
||||
@@ -1375,6 +1411,45 @@ def test_compile_and_run_lut_correctness(
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,table,bit_width",
|
||||
[
|
||||
pytest.param(*negative_identity_smaller_lut_generator(n), n, id=f"smaller ({n}-bit)")
|
||||
for n in range(1, 8)
|
||||
]
|
||||
+ [
|
||||
pytest.param(*negative_identity_lut_generator(n), n, id=f"normal ({n}-bit)")
|
||||
for n in range(1, 8)
|
||||
]
|
||||
+ [
|
||||
pytest.param(*negative_identity_bigger_lut_generator(n), n, id=f"bigger ({n}-bit)")
|
||||
for n in range(1, 7)
|
||||
]
|
||||
+ [
|
||||
pytest.param(*weird_lut(3), 3, id="weird"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_negative_lut_correctness(
|
||||
function,
|
||||
table,
|
||||
bit_width,
|
||||
default_compilation_configuration,
|
||||
):
|
||||
"""Test correctness when running a compiled function with LUT using negative values"""
|
||||
|
||||
circuit = compile_numpy_function(
|
||||
function,
|
||||
{"x": EncryptedScalar(UnsignedInteger(bit_width))},
|
||||
range(2 ** bit_width),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
offset = 2 ** (bit_width - 1)
|
||||
for value in range(-offset, offset):
|
||||
assert table[value] == function(value + offset)
|
||||
check_is_good_execution(circuit, function, [value + offset])
|
||||
|
||||
|
||||
def test_compile_and_run_multi_lut_correctness(default_compilation_configuration):
|
||||
"""Test correctness of results when running a compiled function with Multi LUT"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user