mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 12:44:57 -05:00
test(execution): add table lookup correctness tests
This commit is contained in:
@@ -22,15 +22,101 @@ def no_fuse_unhandled(x, y):
|
||||
return intermediate.astype(numpy.int32)
|
||||
|
||||
|
||||
def lut(x):
|
||||
def identity_lut_generator(n):
|
||||
"""Test lookup table"""
|
||||
table = LookupTable(list(range(128)))
|
||||
return lambda x: LookupTable(list(range(2 ** n)))[x]
|
||||
|
||||
|
||||
def random_lut_1b(x):
|
||||
"""1-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable([10, 12])
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
def small_lut(x):
|
||||
"""Test lookup table with small size and output"""
|
||||
table = LookupTable(list(range(32)))
|
||||
def random_lut_2b(x):
|
||||
"""2-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable([3, 8, 22, 127])
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
def random_lut_3b(x):
|
||||
"""3-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4])
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
def random_lut_4b(x):
|
||||
"""4-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4, 21, 51, 22, 15, 53, 100, 75, 90])
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
def random_lut_5b(x):
|
||||
"""5-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable(
|
||||
[
|
||||
1, 5, 2, 3, 10, 2, 4, 8, 1, 12, 15, 12, 10, 1, 0, 2,
|
||||
4, 3, 8, 7, 10, 11, 6, 13, 9, 0, 2, 1, 15, 11, 12, 5
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
def random_lut_6b(x):
|
||||
"""6-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable(
|
||||
[
|
||||
95, 74, 11, 83, 24, 116, 28, 75, 26, 85, 114, 121, 91, 123, 78, 69,
|
||||
72, 115, 67, 5, 39, 11, 120, 88, 56, 43, 74, 16, 72, 85, 103, 92,
|
||||
44, 115, 50, 56, 107, 77, 25, 71, 52, 45, 80, 35, 69, 8, 40, 87,
|
||||
26, 85, 84, 53, 73, 95, 86, 22, 16, 45, 59, 112, 53, 113, 98, 116
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
def random_lut_7b(x):
|
||||
"""7-bit random table lookup"""
|
||||
|
||||
# fmt: off
|
||||
table = LookupTable(
|
||||
[
|
||||
13, 58, 38, 58, 15, 15, 77, 86, 80, 94, 108, 27, 126, 60, 65, 95,
|
||||
50, 79, 22, 97, 38, 60, 25, 48, 73, 112, 27, 45, 88, 20, 67, 17,
|
||||
16, 6, 71, 60, 77, 43, 93, 40, 41, 31, 99, 122, 120, 40, 94, 13,
|
||||
111, 44, 96, 62, 108, 91, 34, 90, 103, 58, 3, 103, 19, 69, 55, 108,
|
||||
0, 111, 113, 0, 0, 73, 22, 52, 81, 2, 88, 76, 36, 121, 97, 121,
|
||||
123, 79, 82, 120, 12, 65, 54, 101, 90, 52, 84, 106, 23, 15, 110, 79,
|
||||
85, 101, 30, 61, 104, 35, 81, 30, 98, 44, 111, 32, 68, 18, 45, 123,
|
||||
84, 80, 68, 27, 31, 38, 126, 61, 51, 7, 49, 37, 63, 114, 22, 18,
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
return table[x]
|
||||
|
||||
|
||||
@@ -352,8 +438,20 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
|
||||
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
|
||||
pytest.param(lut, ((0, 127),), ["x"]),
|
||||
pytest.param(small_lut, ((0, 31),), ["x"]),
|
||||
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"]),
|
||||
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"]),
|
||||
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"]),
|
||||
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"]),
|
||||
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"]),
|
||||
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"]),
|
||||
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"]),
|
||||
pytest.param(random_lut_1b, ((0, 1),), ["x"]),
|
||||
pytest.param(random_lut_2b, ((0, 3),), ["x"]),
|
||||
pytest.param(random_lut_3b, ((0, 7),), ["x"]),
|
||||
pytest.param(random_lut_4b, ((0, 15),), ["x"]),
|
||||
pytest.param(random_lut_5b, ((0, 31),), ["x"]),
|
||||
pytest.param(random_lut_6b, ((0, 63),), ["x"]),
|
||||
pytest.param(random_lut_7b, ((0, 127),), ["x"]),
|
||||
pytest.param(small_fused_table, ((0, 31),), ["x"]),
|
||||
],
|
||||
)
|
||||
@@ -523,6 +621,56 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
|
||||
assert right_circuit.run(*args) == right(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"], id="identity function (1-bit)"),
|
||||
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"], id="identity function (2-bit)"),
|
||||
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"], id="identity function (3-bit)"),
|
||||
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"], id="identity function (4-bit)"),
|
||||
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"], id="identity function (5-bit)"),
|
||||
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"], id="identity function (6-bit)"),
|
||||
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"], id="identity function (7-bit)"),
|
||||
pytest.param(random_lut_1b, ((0, 1),), ["x"], id="random function (1-bit)"),
|
||||
pytest.param(random_lut_2b, ((0, 3),), ["x"], id="random function (2-bit)"),
|
||||
pytest.param(random_lut_3b, ((0, 7),), ["x"], id="random function (3-bit)"),
|
||||
pytest.param(random_lut_4b, ((0, 15),), ["x"], id="random function (4-bit)"),
|
||||
pytest.param(random_lut_5b, ((0, 31),), ["x"], id="random function (5-bit)"),
|
||||
pytest.param(random_lut_6b, ((0, 63),), ["x"], id="random function (6-bit)"),
|
||||
pytest.param(random_lut_7b, ((0, 127),), ["x"], id="random function (7-bit)"),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_names):
|
||||
"""Test correctness of results when running a compiled function with LUT"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
yield prod
|
||||
|
||||
function_parameters = {
|
||||
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
|
||||
}
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
# testing random values
|
||||
for _ in range(10):
|
||||
args = [random.randint(low, high) for (low, high) in input_ranges]
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
# testing low values
|
||||
args = [low for (low, _) in input_ranges]
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
# testing high values
|
||||
args = [high for (_, high) in input_ranges]
|
||||
check_is_good_execution(compiler_engine, function, args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user