mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test: dot compilation and execution
This commit is contained in:
@@ -9,7 +9,7 @@ from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph,
|
||||
@@ -139,6 +139,54 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size, input_range",
|
||||
[
|
||||
pytest.param(
|
||||
1,
|
||||
(0, 8),
|
||||
),
|
||||
pytest.param(
|
||||
4,
|
||||
(0, 8),
|
||||
),
|
||||
pytest.param(
|
||||
8,
|
||||
(0, 8),
|
||||
),
|
||||
pytest.param(
|
||||
16,
|
||||
(0, 4),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_dot_correctness(size, input_range):
|
||||
"""Test correctness of results when running a compiled function"""
|
||||
|
||||
def data_gen(input_range, size):
|
||||
for i in range(*input_range, size):
|
||||
vec = list(range(i, min(i + size, input_range[1])))
|
||||
yield vec, vec[::-1]
|
||||
|
||||
function_parameters = {
|
||||
"x": EncryptedTensor(Integer(64, False), (size,)),
|
||||
"y": ClearTensor(Integer(64, False), (size,)),
|
||||
}
|
||||
|
||||
def function(x, y):
|
||||
return numpy.dot(x, y)
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
function,
|
||||
function_parameters,
|
||||
data_gen(input_range, size),
|
||||
)
|
||||
|
||||
low, high = input_range
|
||||
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
def test_compile_function_with_direct_tlu():
|
||||
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user