mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test(direct-tlu): create debugging and compilation tests for direct table lookup
This commit is contained in:
@@ -30,19 +30,19 @@ class LookupTable:
|
||||
self.table = table
|
||||
self.output_dtype = make_integer_to_hold_ints(table, force_signed=False)
|
||||
|
||||
def __getitem__(self, item: Union[int, BaseTracer]):
|
||||
def __getitem__(self, key: Union[int, BaseTracer]):
|
||||
# if a tracer is used for indexing,
|
||||
# we need to create an `ArbitraryFunction` node
|
||||
# because the result will be determined during the runtime
|
||||
if isinstance(item, BaseTracer):
|
||||
if isinstance(key, BaseTracer):
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
input_base_value=item.output,
|
||||
arbitrary_func=lambda x, table: table[x],
|
||||
input_base_value=key.output,
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=self.output_dtype,
|
||||
op_kwargs={"table": deepcopy(self.table)},
|
||||
)
|
||||
return item.__class__(
|
||||
inputs=[item],
|
||||
return key.__class__(
|
||||
inputs=[key],
|
||||
traced_computation=traced_computation,
|
||||
output_index=0,
|
||||
)
|
||||
@@ -50,4 +50,14 @@ class LookupTable:
|
||||
# if not, it means table is indexed with a constant
|
||||
# thus, the result of the lookup is a constant
|
||||
# so, we can propagate it directly
|
||||
return self.table[item]
|
||||
return LookupTable._checked_indexing(key, self.table)
|
||||
|
||||
@staticmethod
|
||||
def _checked_indexing(x, table):
|
||||
if x < 0 or x >= len(table):
|
||||
raise ValueError(
|
||||
f"Lookup table with {len(table)} entries cannot be indexed with {x} "
|
||||
f"(you should check your dataset)",
|
||||
)
|
||||
|
||||
return table[x]
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import EncryptedValue
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.hnumpy.compile import compile_numpy_function
|
||||
|
||||
|
||||
@@ -45,3 +46,37 @@ def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_n
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
"""Test compile_numpy_function for a program with direct table lookup"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
|
||||
def function(x):
|
||||
return x + table[x]
|
||||
|
||||
op_graph = compile_numpy_function(
|
||||
function,
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
iter([(0,), (1,), (2,), (3,)]),
|
||||
)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu_overflow():
|
||||
"""Test compile_numpy_function for a program with direct table lookup overflow"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
|
||||
def function(x):
|
||||
return table[x]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
compile_numpy_function(
|
||||
function,
|
||||
{"x": EncryptedValue(Integer(3, is_signed=False))},
|
||||
iter([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]),
|
||||
)
|
||||
|
||||
@@ -5,8 +5,12 @@ import pytest
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
LOOKUP_TABLE_FROM_2B_TO_4B = LookupTable([9, 2, 4, 11])
|
||||
LOOKUP_TABLE_FROM_3B_TO_2B = LookupTable([0, 1, 3, 2, 2, 3, 1, 0])
|
||||
|
||||
|
||||
def issue_130_a(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
@@ -143,6 +147,39 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x\n%1 = ArbitraryFunction(0)\nreturn(%1)",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x"
|
||||
"\n%1 = ConstantInput(4)"
|
||||
"\n%2 = Add(0, 1)"
|
||||
"\n%3 = ArbitraryFunction(2)"
|
||||
"\nreturn(%3)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hnumpy_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"Test hnumpy get_printable_graph and draw_graph on graphs with direct table lookup"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
|
||||
print(f"\nGot {str_of_the_graph}\n")
|
||||
print(f"\nExp {ref_graph_str}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
|
||||
# Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b
|
||||
# returning 23b), since they are replaced later by the real bitwidths computed on the
|
||||
# dataset
|
||||
@@ -174,7 +211,7 @@ def test_hnumpy_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
],
|
||||
)
|
||||
def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"Test hnumpy get_printable_graph with show_data_types"
|
||||
"""Test hnumpy get_printable_graph with show_data_types"""
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
@@ -184,3 +221,48 @@ def test_hnumpy_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
print(f"\nExp {ref_graph_str}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,params,ref_graph_str",
|
||||
[
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[x],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = ArbitraryFunction(0) # Integer<unsigned, 4 bits>"
|
||||
"\nreturn(%1)",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_3B_TO_2B[x + 4],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
|
||||
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
|
||||
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
|
||||
"\nreturn(%3)",
|
||||
),
|
||||
(
|
||||
lambda x: LOOKUP_TABLE_FROM_2B_TO_4B[LOOKUP_TABLE_FROM_3B_TO_2B[x + 4]],
|
||||
{"x": EncryptedValue(Integer(2, is_signed=False))},
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = ConstantInput(4) # Integer<unsigned, 3 bits>"
|
||||
"\n%2 = Add(0, 1) # Integer<unsigned, 3 bits>"
|
||||
"\n%3 = ArbitraryFunction(2) # Integer<unsigned, 2 bits>"
|
||||
"\n%4 = ArbitraryFunction(3) # Integer<unsigned, 4 bits>"
|
||||
"\nreturn(%4)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hnumpy_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"""Test hnumpy get_printable_graph with show_data_types on graphs with direct table lookup"""
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
|
||||
print(f"\nGot {str_of_the_graph}\n")
|
||||
print(f"\nExp {ref_graph_str}\n")
|
||||
|
||||
assert str_of_the_graph == ref_graph_str
|
||||
|
||||
Reference in New Issue
Block a user