test(direct-tlu): create debugging and compilation tests for direct table lookup

This commit is contained in:
Umut
2021-08-12 16:06:17 +03:00
parent 4afc373a6b
commit 2c3c080923
3 changed files with 135 additions and 8 deletions

View File

@@ -30,19 +30,19 @@ class LookupTable:
self.table = table
self.output_dtype = make_integer_to_hold_ints(table, force_signed=False)
def __getitem__(self, item: Union[int, BaseTracer]):
def __getitem__(self, key: Union[int, BaseTracer]):
# if a tracer is used for indexing,
# we need to create an `ArbitraryFunction` node
# because the result will be determined during the runtime
if isinstance(item, BaseTracer):
if isinstance(key, BaseTracer):
traced_computation = ir.ArbitraryFunction(
input_base_value=item.output,
arbitrary_func=lambda x, table: table[x],
input_base_value=key.output,
arbitrary_func=LookupTable._checked_indexing,
output_dtype=self.output_dtype,
op_kwargs={"table": deepcopy(self.table)},
)
return item.__class__(
inputs=[item],
return key.__class__(
inputs=[key],
traced_computation=traced_computation,
output_index=0,
)
@@ -50,4 +50,14 @@ class LookupTable:
# if not, it means table is indexed with a constant
# thus, the result of the lookup is a constant
# so, we can propagate it directly
return self.table[item]
return LookupTable._checked_indexing(key, self.table)
@staticmethod
def _checked_indexing(x, table):
if x < 0 or x >= len(table):
raise ValueError(
f"Lookup table with {len(table)} entries cannot be indexed with {x} "
f"(you should check your dataset)",
)
return table[x]

View File

@@ -6,6 +6,7 @@ import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import EncryptedValue
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.common.extensions.table import LookupTable
from hdk.hnumpy.compile import compile_numpy_function
@@ -45,3 +46,37 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
print(f"\n{str_of_the_graph}\n")
def test_compile_function_with_direct_tlu():
"""Test compile_numpy_function for a program with direct table lookup"""
table = LookupTable([9, 2, 4, 11])
def function(x):
return x + table[x]
op_graph = compile_numpy_function(
function,
{"x": EncryptedValue(Integer(2, is_signed=False))},
iter([(0,), (1,), (2,), (3,)]),
)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
print(f"\n{str_of_the_graph}\n")
def test_compile_function_with_direct_tlu_overflow():
"""Test compile_numpy_function for a program with direct table lookup overflow"""
table = LookupTable([9, 2, 4, 11])
def function(x):
return table[x]
with pytest.raises(ValueError):
compile_numpy_function(
function,
{"x": EncryptedValue(Integer(3, is_signed=False))},
iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]),
)

View File

@@ -5,8 +5,12 @@ import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.common.extensions.table import LookupTable
from hdk.hnumpy import tracing
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
LOOKUP_TABLE_FROM_3B_TO_2B = LookupTable([0, 1, 3, 2, 2, 3, 1, 0])
def issue_130_a(x, y):
"""Test case derived from issue #130"""
@@ -143,6 +147,39 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
assert str_of_the_graph == ref_graph_str
@pytest.mark.parametrize(
"lambda_f,params,ref_graph_str",
[
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x\n%1 = ArbitraryFunction(0)\nreturn(%1)",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x"
"\n%1 = ConstantInput(4)"
"\n%2 = Add(0, 1)"
"\n%3 = ArbitraryFunction(2)"
"\nreturn(%3)",
),
],
)
def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
"Test hnumpy get_printable_graph and draw_graph on graphs with direct table lookup"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, block_until_user_closes_graph=False)
str_of_the_graph = get_printable_graph(graph)
print(f"\nGot {str_of_the_graph}\n")
print(f"\nExp {ref_graph_str}\n")
assert str_of_the_graph == ref_graph_str
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
# returning 23b), since they are replaced later by the real bitwidths computed on the
# dataset
@@ -174,7 +211,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
],
)
def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
"Test hnumpy get_printable_graph with show_data_types"
"""Test hnumpy get_printable_graph with show_data_types"""
x, y = x_y
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
@@ -184,3 +221,48 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
print(f"\nExp {ref_graph_str}\n")
assert str_of_the_graph == ref_graph_str
@pytest.mark.parametrize(
"lambda_f,params,ref_graph_str",
[
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = ArbitraryFunction(0) # Integer<unsigned, 4 bits>"
"\nreturn(%1)",
),
(
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
"\nreturn(%3)",
),
(
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
{"x": EncryptedValue(Integer(2, is_signed=False))},
"\n%0 = x # Integer<unsigned, 2 bits>"
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
"\n%4 = ArbitraryFunction(3) # Integer<unsigned, 4 bits>"
"\nreturn(%4)",
),
],
)
def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
"""Test hnumpy get_printable_graph with show_data_types on graphs with direct table lookup"""
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, block_until_user_closes_graph=False)
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
print(f"\nGot {str_of_the_graph}\n")
print(f"\nExp {ref_graph_str}\n")
assert str_of_the_graph == ref_graph_str