From 9a3e15e89ab6d99ee1ade0c86b103adee79b7a59 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Wed, 25 Aug 2021 15:58:23 +0200 Subject: [PATCH] test: add tests of np.dot with compile_numpy_function_into_op_graph closes #201 --- tests/hnumpy/test_compile.py | 59 +++++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index 353bd7ec7..0a62b19a0 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -8,7 +8,7 @@ import pytest from hdk.common.data_types.integers import Integer from hdk.common.debugging import draw_graph, get_printable_graph from hdk.common.extensions.table import LookupTable -from hdk.common.values import EncryptedValue +from hdk.common.values import EncryptedTensor, EncryptedValue from hdk.hnumpy.compile import ( compile_numpy_function, compile_numpy_function_into_op_graph, @@ -43,7 +43,7 @@ def no_fuse_unhandled(x, y): ], ) def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names): - """Test function compile_numpy_function for a program with multiple outputs""" + """Test function compile_numpy_function_into_op_graph for a program with multiple outputs""" def data_gen(args): for prod in itertools.product(*args): @@ -131,7 +131,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names): def test_compile_function_with_direct_tlu(): - """Test compile_numpy_function for a program with direct table lookup""" + """Test compile_numpy_function_into_op_graph for a program with direct table lookup""" table = LookupTable([9, 2, 4, 11]) @@ -149,7 +149,7 @@ def test_compile_function_with_direct_tlu(): def test_compile_function_with_direct_tlu_overflow(): - """Test compile_numpy_function for a program with direct table lookup overflow""" + """Test compile_numpy_function_into_op_graph for a program with direct table lookup overflow""" table = LookupTable([9, 2, 4, 11]) @@ -171,7 +171,7 @@ def test_compile_function_with_direct_tlu_overflow(): ], ) def test_fail_compile(function, input_ranges, list_of_arg_names): - """Test function compile_numpy_function for a program with signed values""" + """Test function compile_numpy_function_into_op_graph for a program with signed values""" def data_gen(args): for prod in itertools.product(*args): @@ -187,3 +187,52 @@ def test_fail_compile(function, input_ranges, list_of_arg_names): function_parameters, data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), ) + + +@pytest.mark.parametrize( + "function,params,shape,ref_graph_str", + [ + # pylint: disable=unnecessary-lambda + ( + lambda x, y: numpy.dot(x, y), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(4,)), + "y": EncryptedTensor(Integer(2, is_signed=False), shape=(4,)), + }, + (4,), + # Remark that, when you do the dot of tensors of 4 values between 0 and 3, + # you can get a maximal value of 4*3*3 = 36, ie something on 6 bits + "\n%0 = x # Integer" + "\n%1 = y # Integer" + "\n%2 = Dot(0, 1) # Integer" + "\nreturn(%2)", + ), + # pylint: enable=unnecessary-lambda + ], +) +def test_compile_function_with_dot(function, params, shape, ref_graph_str): + """Test compile_numpy_function_into_op_graph for a program with np.dot""" + + # This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'), + # we'll have to take random values, not all values one by one + def data_gen(max_for_ij, repeat): + iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat) + iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat) + for prod_i, prod_j in itertools.product(iter_i, iter_j): + yield (prod_i, prod_j) + + max_for_ij = 3 + assert len(shape) == 1 + repeat = shape[0] + + op_graph = compile_numpy_function_into_op_graph( + function, + params, + data_gen(max_for_ij, repeat), + ) + str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + assert str_of_the_graph == ref_graph_str, ( + f"\n==================\nGot {str_of_the_graph}" + f"\n==================\nExpected {ref_graph_str}" + f"\n==================\n" + )