test(execution): add table lookup correctness tests

This commit is contained in:
Umut
2021-10-15 14:50:19 +03:00
parent e164671608
commit 8b2efb8869

View File

@@ -22,15 +22,101 @@ def no_fuse_unhandled(x, y):
return intermediate.astype(numpy.int32)
def lut(x):
def identity_lut_generator(n):
"""Test lookup table"""
table = LookupTable(list(range(128)))
return lambda x: LookupTable(list(range(2 ** n)))[x]
def random_lut_1b(x):
"""1-bit random table lookup"""
# fmt: off
table = LookupTable([10, 12])
# fmt: on
return table[x]
def small_lut(x):
"""Test lookup table with small size and output"""
table = LookupTable(list(range(32)))
def random_lut_2b(x):
"""2-bit random table lookup"""
# fmt: off
table = LookupTable([3, 8, 22, 127])
# fmt: on
return table[x]
def random_lut_3b(x):
"""3-bit random table lookup"""
# fmt: off
table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4])
# fmt: on
return table[x]
def random_lut_4b(x):
"""4-bit random table lookup"""
# fmt: off
table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4, 21, 51, 22, 15, 53, 100, 75, 90])
# fmt: on
return table[x]
def random_lut_5b(x):
"""5-bit random table lookup"""
# fmt: off
table = LookupTable(
[
1, 5, 2, 3, 10, 2, 4, 8, 1, 12, 15, 12, 10, 1, 0, 2,
4, 3, 8, 7, 10, 11, 6, 13, 9, 0, 2, 1, 15, 11, 12, 5
]
)
# fmt: on
return table[x]
def random_lut_6b(x):
"""6-bit random table lookup"""
# fmt: off
table = LookupTable(
[
95, 74, 11, 83, 24, 116, 28, 75, 26, 85, 114, 121, 91, 123, 78, 69,
72, 115, 67, 5, 39, 11, 120, 88, 56, 43, 74, 16, 72, 85, 103, 92,
44, 115, 50, 56, 107, 77, 25, 71, 52, 45, 80, 35, 69, 8, 40, 87,
26, 85, 84, 53, 73, 95, 86, 22, 16, 45, 59, 112, 53, 113, 98, 116
]
)
# fmt: on
return table[x]
def random_lut_7b(x):
"""7-bit random table lookup"""
# fmt: off
table = LookupTable(
[
13, 58, 38, 58, 15, 15, 77, 86, 80, 94, 108, 27, 126, 60, 65, 95,
50, 79, 22, 97, 38, 60, 25, 48, 73, 112, 27, 45, 88, 20, 67, 17,
16, 6, 71, 60, 77, 43, 93, 40, 41, 31, 99, 122, 120, 40, 94, 13,
111, 44, 96, 62, 108, 91, 34, 90, 103, 58, 3, 103, 19, 69, 55, 108,
0, 111, 113, 0, 0, 73, 22, 52, 81, 2, 88, 76, 36, 121, 97, 121,
123, 79, 82, 120, 12, 65, 54, 101, 90, 52, 84, 106, 23, 15, 110, 79,
85, 101, 30, 61, 104, 35, 81, 30, 98, 44, 111, 32, 68, 18, 45, 123,
84, 80, 68, 27, 31, 38, 126, 61, 51, 7, 49, 37, 63, 114, 22, 18,
]
)
# fmt: on
return table[x]
@@ -352,8 +438,20 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
pytest.param(lambda x: x * 2, ((0, 10),), ["x"]),
pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]),
pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]),
pytest.param(lut, ((0, 127),), ["x"]),
pytest.param(small_lut, ((0, 31),), ["x"]),
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"]),
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"]),
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"]),
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"]),
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"]),
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"]),
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"]),
pytest.param(random_lut_1b, ((0, 1),), ["x"]),
pytest.param(random_lut_2b, ((0, 3),), ["x"]),
pytest.param(random_lut_3b, ((0, 7),), ["x"]),
pytest.param(random_lut_4b, ((0, 15),), ["x"]),
pytest.param(random_lut_5b, ((0, 31),), ["x"]),
pytest.param(random_lut_6b, ((0, 63),), ["x"]),
pytest.param(random_lut_7b, ((0, 127),), ["x"]),
pytest.param(small_fused_table, ((0, 31),), ["x"]),
],
)
@@ -523,6 +621,56 @@ def test_compile_and_run_constant_dot_correctness(size, input_range):
assert right_circuit.run(*args) == right(*args)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(identity_lut_generator(1), ((0, 1),), ["x"], id="identity function (1-bit)"),
pytest.param(identity_lut_generator(2), ((0, 3),), ["x"], id="identity function (2-bit)"),
pytest.param(identity_lut_generator(3), ((0, 7),), ["x"], id="identity function (3-bit)"),
pytest.param(identity_lut_generator(4), ((0, 15),), ["x"], id="identity function (4-bit)"),
pytest.param(identity_lut_generator(5), ((0, 31),), ["x"], id="identity function (5-bit)"),
pytest.param(identity_lut_generator(6), ((0, 63),), ["x"], id="identity function (6-bit)"),
pytest.param(identity_lut_generator(7), ((0, 127),), ["x"], id="identity function (7-bit)"),
pytest.param(random_lut_1b, ((0, 1),), ["x"], id="random function (1-bit)"),
pytest.param(random_lut_2b, ((0, 3),), ["x"], id="random function (2-bit)"),
pytest.param(random_lut_3b, ((0, 7),), ["x"], id="random function (3-bit)"),
pytest.param(random_lut_4b, ((0, 15),), ["x"], id="random function (4-bit)"),
pytest.param(random_lut_5b, ((0, 31),), ["x"], id="random function (5-bit)"),
pytest.param(random_lut_6b, ((0, 63),), ["x"], id="random function (6-bit)"),
pytest.param(random_lut_7b, ((0, 127),), ["x"], id="random function (7-bit)"),
],
)
def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_names):
"""Test correctness of results when running a compiled function with LUT"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names
}
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
)
# testing random values
for _ in range(10):
args = [random.randint(low, high) for (low, high) in input_ranges]
check_is_good_execution(compiler_engine, function, args)
# testing low values
args = [low for (low, _) in input_ranges]
check_is_good_execution(compiler_engine, function, args)
# testing high values
args = [high for (_, high) in input_ranges]
check_is_good_execution(compiler_engine, function, args)
def test_compile_function_with_direct_tlu():
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""