mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
test: add tests of np.dot with compile_numpy_function_into_op_graph
closes #201
This commit is contained in:
committed by
Benoit Chevallier
parent
1ebbd3ea91
commit
9a3e15e89a
@@ -8,7 +8,7 @@ import pytest
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.common.extensions.table import LookupTable
|
||||
from hdk.common.values import EncryptedValue
|
||||
from hdk.common.values import EncryptedTensor, EncryptedValue
|
||||
from hdk.hnumpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph,
|
||||
@@ -43,7 +43,7 @@ def no_fuse_unhandled(x, y):
|
||||
],
|
||||
)
|
||||
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with multiple outputs"""
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with multiple outputs"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
@@ -131,7 +131,7 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
"""Test compile_numpy_function for a program with direct table lookup"""
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
|
||||
@@ -149,7 +149,7 @@ def test_compile_function_with_direct_tlu():
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu_overflow():
|
||||
"""Test compile_numpy_function for a program with direct table lookup overflow"""
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup overflow"""
|
||||
|
||||
table = LookupTable([9, 2, 4, 11])
|
||||
|
||||
@@ -171,7 +171,7 @@ def test_compile_function_with_direct_tlu_overflow():
|
||||
],
|
||||
)
|
||||
def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
"""Test function compile_numpy_function for a program with signed values"""
|
||||
"""Test function compile_numpy_function_into_op_graph for a program with signed values"""
|
||||
|
||||
def data_gen(args):
|
||||
for prod in itertools.product(*args):
|
||||
@@ -187,3 +187,52 @@ def test_fail_compile(function, input_ranges, list_of_arg_names):
|
||||
function_parameters,
|
||||
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,params,shape,ref_graph_str",
|
||||
[
|
||||
# pylint: disable=unnecessary-lambda
|
||||
(
|
||||
lambda x, y: numpy.dot(x, y),
|
||||
{
|
||||
"x": EncryptedTensor(Integer(2, is_signed=False), shape=(4,)),
|
||||
"y": EncryptedTensor(Integer(2, is_signed=False), shape=(4,)),
|
||||
},
|
||||
(4,),
|
||||
# Remark that, when you do the dot of tensors of 4 values between 0 and 3,
|
||||
# you can get a maximal value of 4*3*3 = 36, ie something on 6 bits
|
||||
"\n%0 = x # Integer<unsigned, 2 bits>"
|
||||
"\n%1 = y # Integer<unsigned, 2 bits>"
|
||||
"\n%2 = Dot(0, 1) # Integer<unsigned, 6 bits>"
|
||||
"\nreturn(%2)",
|
||||
),
|
||||
# pylint: enable=unnecessary-lambda
|
||||
],
|
||||
)
|
||||
def test_compile_function_with_dot(function, params, shape, ref_graph_str):
|
||||
"""Test compile_numpy_function_into_op_graph for a program with np.dot"""
|
||||
|
||||
# This is the exhaust, but if ever we have too long inputs (ie, large 'repeat'),
|
||||
# we'll have to take random values, not all values one by one
|
||||
def data_gen(max_for_ij, repeat):
|
||||
iter_i = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
iter_j = itertools.product(range(0, max_for_ij + 1), repeat=repeat)
|
||||
for prod_i, prod_j in itertools.product(iter_i, iter_j):
|
||||
yield (prod_i, prod_j)
|
||||
|
||||
max_for_ij = 3
|
||||
assert len(shape) == 1
|
||||
repeat = shape[0]
|
||||
|
||||
op_graph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
params,
|
||||
data_gen(max_for_ij, repeat),
|
||||
)
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot {str_of_the_graph}"
|
||||
f"\n==================\nExpected {ref_graph_str}"
|
||||
f"\n==================\n"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user