From 2c3c0809232cf682e7983201678d18a10996250d Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 12 Aug 2021 16:06:17 +0300 Subject: [PATCH] test(direct-tlu): create debugging and compilation tests for direct table lookup --- hdk/common/extensions/table.py | 24 +++++++--- tests/hnumpy/test_compile.py | 35 ++++++++++++++ tests/hnumpy/test_debugging.py | 84 +++++++++++++++++++++++++++++++++- 3 files changed, 135 insertions(+), 8 deletions(-) diff --git a/hdk/common/extensions/table.py b/hdk/common/extensions/table.py index 38995f07b..0ef1e0041 100644 --- a/hdk/common/extensions/table.py +++ b/hdk/common/extensions/table.py @@ -30,19 +30,19 @@ class LookupTable: self.table = table self.output_dtype = make_integer_to_hold_ints(table, force_signed=False) - def __getitem__(self, item: Union[int, BaseTracer]): + def __getitem__(self, key: Union[int, BaseTracer]): # if a tracer is used for indexing, # we need to create an `ArbitraryFunction` node # because the result will be determined during the runtime - if isinstance(item, BaseTracer): + if isinstance(key, BaseTracer): traced_computation = ir.ArbitraryFunction( - input_base_value=item.output, - arbitrary_func=lambda x, table: table[x], + input_base_value=key.output, + arbitrary_func=LookupTable._checked_indexing, output_dtype=self.output_dtype, op_kwargs={"table": deepcopy(self.table)}, ) - return item.__class__( - inputs=[item], + return key.__class__( + inputs=[key], traced_computation=traced_computation, output_index=0, ) @@ -50,4 +50,14 @@ class LookupTable: # if not, it means table is indexed with a constant # thus, the result of the lookup is a constant # so, we can propagate it directly - return self.table[item] + return LookupTable._checked_indexing(key, self.table) + + @staticmethod + def _checked_indexing(x, table): + if x < 0 or x >= len(table): + raise ValueError( + f"Lookup table with {len(table)} entries cannot be indexed with {x} " + f"(you should check your dataset)", + ) + + return table[x] diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index fb7a76147..e6f4b8996 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -6,6 +6,7 @@ import pytest from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import EncryptedValue from hdk.common.debugging import draw_graph, get_printable_graph +from hdk.common.extensions.table import LookupTable from hdk.hnumpy.compile import compile_numpy_function @@ -45,3 +46,37 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) print(f"\n{str_of_the_graph}\n") + + +def test_compile_function_with_direct_tlu(): + """Test compile_numpy_function for a program with direct table lookup""" + + table = LookupTable([9, 2, 4, 11]) + + def function(x): + return x + table[x] + + op_graph = compile_numpy_function( + function, + {"x": EncryptedValue(Integer(2, is_signed=False))}, + iter([(0,), (1,), (2,), (3,)]), + ) + + str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + print(f"\n{str_of_the_graph}\n") + + +def test_compile_function_with_direct_tlu_overflow(): + """Test compile_numpy_function for a program with direct table lookup overflow""" + + table = LookupTable([9, 2, 4, 11]) + + def function(x): + return table[x] + + with pytest.raises(ValueError): + compile_numpy_function( + function, + {"x": EncryptedValue(Integer(3, is_signed=False))}, + iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]), + ) diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 6a51a4374..9e5972a1b 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -5,8 +5,12 @@ import pytest from hdk.common.data_types.integers import Integer from hdk.common.data_types.values import ClearValue, EncryptedValue from hdk.common.debugging import draw_graph, get_printable_graph +from hdk.common.extensions.table import LookupTable from hdk.hnumpy import tracing +LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11]) +LOOKUP_TABLE_FROM_3B_TO_2B = LookupTable([0, 1, 3, 2, 2, 3, 1, 0]) + def issue_130_a(x, y): """Test case derived from issue #130""" @@ -143,6 +147,39 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): assert str_of_the_graph == ref_graph_str +@pytest.mark.parametrize( + "lambda_f,params,ref_graph_str", + [ + ( + lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], + {"x": EncryptedValue(Integer(2, is_signed=False))}, + "\n%0 = x\n%1 = ArbitraryFunction(0)\nreturn(%1)", + ), + ( + lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], + {"x": EncryptedValue(Integer(2, is_signed=False))}, + "\n%0 = x" + "\n%1 = ConstantInput(4)" + "\n%2 = Add(0, 1)" + "\n%3 = ArbitraryFunction(2)" + "\nreturn(%3)", + ), + ], +) +def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str): + "Test hnumpy get_printable_graph and draw_graph on graphs with direct table lookup" + graph = tracing.trace_numpy_function(lambda_f, params) + + draw_graph(graph, block_until_user_closes_graph=False) + + str_of_the_graph = get_printable_graph(graph) + + print(f"\nGot {str_of_the_graph}\n") + print(f"\nExp {ref_graph_str}\n") + + assert str_of_the_graph == ref_graph_str + + # Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b # returning 23b), since they are replaced later by the real bitwidths computed on the # dataset @@ -174,7 +211,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y): ], ) def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): - "Test hnumpy get_printable_graph with show_data_types" + """Test hnumpy get_printable_graph with show_data_types""" x, y = x_y graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) @@ -184,3 +221,48 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str): print(f"\nExp {ref_graph_str}\n") assert str_of_the_graph == ref_graph_str + + +@pytest.mark.parametrize( + "lambda_f,params,ref_graph_str", + [ + ( + lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x], + {"x": EncryptedValue(Integer(2, is_signed=False))}, + "\n%0 = x # Integer" + "\n%1 = ArbitraryFunction(0) # Integer" + "\nreturn(%1)", + ), + ( + lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4], + {"x": EncryptedValue(Integer(2, is_signed=False))}, + "\n%0 = x # Integer" + "\n%1 = ConstantInput(4) # Integer" + "\n%2 = Add(0, 1) # Integer" + "\n%3 = ArbitraryFunction(2) # Integer" + "\nreturn(%3)", + ), + ( + lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]], + {"x": EncryptedValue(Integer(2, is_signed=False))}, + "\n%0 = x # Integer" + "\n%1 = ConstantInput(4) # Integer" + "\n%2 = Add(0, 1) # Integer" + "\n%3 = ArbitraryFunction(2) # Integer" + "\n%4 = ArbitraryFunction(3) # Integer" + "\nreturn(%4)", + ), + ], +) +def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str): + """Test hnumpy get_printable_graph with show_data_types on graphs with direct table lookup""" + graph = tracing.trace_numpy_function(lambda_f, params) + + draw_graph(graph, block_until_user_closes_graph=False) + + str_of_the_graph = get_printable_graph(graph, show_data_types=True) + + print(f"\nGot {str_of_the_graph}\n") + print(f"\nExp {ref_graph_str}\n") + + assert str_of_the_graph == ref_graph_str