From 8b2efb8869fc1c3a05f4c3e88574f183c282a007 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 15 Oct 2021 14:50:19 +0300 Subject: [PATCH] test(execution): add table lookup correctness tests --- tests/numpy/test_compile.py | 162 ++++++++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 7 deletions(-) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 6801a9754..dab66562e 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -22,15 +22,101 @@ def no_fuse_unhandled(x, y): return intermediate.astype(numpy.int32) -def lut(x): +def identity_lut_generator(n): """Test lookup table""" - table = LookupTable(list(range(128))) + return lambda x: LookupTable(list(range(2 ** n)))[x] + + +def random_lut_1b(x): + """1-bit random table lookup""" + + # fmt: off + table = LookupTable([10, 12]) + # fmt: on + return table[x] -def small_lut(x): - """Test lookup table with small size and output""" - table = LookupTable(list(range(32))) +def random_lut_2b(x): + """2-bit random table lookup""" + + # fmt: off + table = LookupTable([3, 8, 22, 127]) + # fmt: on + + return table[x] + + +def random_lut_3b(x): + """3-bit random table lookup""" + + # fmt: off + table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4]) + # fmt: on + + return table[x] + + +def random_lut_4b(x): + """4-bit random table lookup""" + + # fmt: off + table = LookupTable([30, 52, 125, 23, 17, 12, 90, 4, 21, 51, 22, 15, 53, 100, 75, 90]) + # fmt: on + + return table[x] + + +def random_lut_5b(x): + """5-bit random table lookup""" + + # fmt: off + table = LookupTable( + [ + 1, 5, 2, 3, 10, 2, 4, 8, 1, 12, 15, 12, 10, 1, 0, 2, + 4, 3, 8, 7, 10, 11, 6, 13, 9, 0, 2, 1, 15, 11, 12, 5 + ] + ) + # fmt: on + + return table[x] + + +def random_lut_6b(x): + """6-bit random table lookup""" + + # fmt: off + table = LookupTable( + [ + 95, 74, 11, 83, 24, 116, 28, 75, 26, 85, 114, 121, 91, 123, 78, 69, + 72, 115, 67, 5, 39, 11, 120, 88, 56, 43, 74, 16, 72, 85, 103, 92, + 44, 115, 50, 56, 107, 77, 25, 71, 52, 45, 80, 35, 69, 8, 40, 87, + 26, 85, 84, 53, 73, 95, 86, 22, 16, 45, 59, 112, 53, 113, 98, 116 + ] + ) + # fmt: on + + return table[x] + + +def random_lut_7b(x): + """7-bit random table lookup""" + + # fmt: off + table = LookupTable( + [ + 13, 58, 38, 58, 15, 15, 77, 86, 80, 94, 108, 27, 126, 60, 65, 95, + 50, 79, 22, 97, 38, 60, 25, 48, 73, 112, 27, 45, 88, 20, 67, 17, + 16, 6, 71, 60, 77, 43, 93, 40, 41, 31, 99, 122, 120, 40, 94, 13, + 111, 44, 96, 62, 108, 91, 34, 90, 103, 58, 3, 103, 19, 69, 55, 108, + 0, 111, 113, 0, 0, 73, 22, 52, 81, 2, 88, 76, 36, 121, 97, 121, + 123, 79, 82, 120, 12, 65, 54, 101, 90, 52, 84, 106, 23, 15, 110, 79, + 85, 101, 30, 61, 104, 35, 81, 30, 98, 44, 111, 32, 68, 18, 45, 123, + 84, 80, 68, 27, 31, 38, 126, 61, 51, 7, 49, 37, 63, 114, 22, 18, + ] + ) + # fmt: on + return table[x] @@ -352,8 +438,20 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n pytest.param(lambda x: x * 2, ((0, 10),), ["x"]), pytest.param(lambda x: 12 - x, ((0, 10),), ["x"]), pytest.param(lambda x, y: x + y + 8, ((2, 10), (4, 8)), ["x", "y"]), - pytest.param(lut, ((0, 127),), ["x"]), - pytest.param(small_lut, ((0, 31),), ["x"]), + pytest.param(identity_lut_generator(1), ((0, 1),), ["x"]), + pytest.param(identity_lut_generator(2), ((0, 3),), ["x"]), + pytest.param(identity_lut_generator(3), ((0, 7),), ["x"]), + pytest.param(identity_lut_generator(4), ((0, 15),), ["x"]), + pytest.param(identity_lut_generator(5), ((0, 31),), ["x"]), + pytest.param(identity_lut_generator(6), ((0, 63),), ["x"]), + pytest.param(identity_lut_generator(7), ((0, 127),), ["x"]), + pytest.param(random_lut_1b, ((0, 1),), ["x"]), + pytest.param(random_lut_2b, ((0, 3),), ["x"]), + pytest.param(random_lut_3b, ((0, 7),), ["x"]), + pytest.param(random_lut_4b, ((0, 15),), ["x"]), + pytest.param(random_lut_5b, ((0, 31),), ["x"]), + pytest.param(random_lut_6b, ((0, 63),), ["x"]), + pytest.param(random_lut_7b, ((0, 127),), ["x"]), pytest.param(small_fused_table, ((0, 31),), ["x"]), ], ) @@ -523,6 +621,56 @@ def test_compile_and_run_constant_dot_correctness(size, input_range): assert right_circuit.run(*args) == right(*args) +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param(identity_lut_generator(1), ((0, 1),), ["x"], id="identity function (1-bit)"), + pytest.param(identity_lut_generator(2), ((0, 3),), ["x"], id="identity function (2-bit)"), + pytest.param(identity_lut_generator(3), ((0, 7),), ["x"], id="identity function (3-bit)"), + pytest.param(identity_lut_generator(4), ((0, 15),), ["x"], id="identity function (4-bit)"), + pytest.param(identity_lut_generator(5), ((0, 31),), ["x"], id="identity function (5-bit)"), + pytest.param(identity_lut_generator(6), ((0, 63),), ["x"], id="identity function (6-bit)"), + pytest.param(identity_lut_generator(7), ((0, 127),), ["x"], id="identity function (7-bit)"), + pytest.param(random_lut_1b, ((0, 1),), ["x"], id="random function (1-bit)"), + pytest.param(random_lut_2b, ((0, 3),), ["x"], id="random function (2-bit)"), + pytest.param(random_lut_3b, ((0, 7),), ["x"], id="random function (3-bit)"), + pytest.param(random_lut_4b, ((0, 15),), ["x"], id="random function (4-bit)"), + pytest.param(random_lut_5b, ((0, 31),), ["x"], id="random function (5-bit)"), + pytest.param(random_lut_6b, ((0, 63),), ["x"], id="random function (6-bit)"), + pytest.param(random_lut_7b, ((0, 127),), ["x"], id="random function (7-bit)"), + ], +) +def test_compile_and_run_lut_correctness(function, input_ranges, list_of_arg_names): + """Test correctness of results when running a compiled function with LUT""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names + } + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + ) + + # testing random values + for _ in range(10): + args = [random.randint(low, high) for (low, high) in input_ranges] + check_is_good_execution(compiler_engine, function, args) + + # testing low values + args = [low for (low, _) in input_ranges] + check_is_good_execution(compiler_engine, function, args) + + # testing high values + args = [high for (_, high) in input_ranges] + check_is_good_execution(compiler_engine, function, args) + + def test_compile_function_with_direct_tlu(): """Test compile_numpy_function_into_op_graph for a program with direct table lookup"""