test: add dot tests with non program inputs args

- also fix some inputsets which had wrongly shaped inputs
This commit is contained in:
Arthur Meyre
2021-12-24 17:18:50 +01:00
committed by Benoit Chevallier
parent b48165045a
commit 11c38efa1e

View File

@@ -1157,7 +1157,12 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
),
]
for _ in range(8):
inputset.append((numpy.random.randint(low, high + 1), numpy.random.randint(low, high + 1)))
inputset.append(
(
numpy.random.randint(low, high + 1, size=shape),
numpy.random.randint(low, high + 1, size=shape),
)
)
function_parameters = {
"x": EncryptedTensor(Integer(64, False), shape),
@@ -1167,15 +1172,21 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
def function(x, y):
return numpy.dot(x, y)
compiler_engine = compile_numpy_function(
function,
function_parameters,
inputset,
default_compilation_configuration,
)
def function_indirect_args(x, y):
return numpy.dot(x.flatten(), y.flatten())
args = [numpy.random.randint(low, high + 1, size=(size,), dtype=numpy.uint8) for __ in range(2)]
assert compiler_engine.run(*args) == function(*args)
for func_to_compile in [function, function_indirect_args]:
compiler_engine = compile_numpy_function(
func_to_compile,
function_parameters,
inputset,
default_compilation_configuration,
)
args = [
numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8) for __ in range(2)
]
assert compiler_engine.run(*args) == func_to_compile(*args)
@pytest.mark.parametrize(