test: test transpose correctness

This commit is contained in:
youben11
2022-03-28 14:53:28 +01:00
committed by Umut
parent a2955d29ea
commit a9d2733230

View File

@@ -2021,6 +2021,33 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
)
@pytest.mark.parametrize(
"input_shape",
[
pytest.param((4,)),
pytest.param((3, 2)),
pytest.param((3, 2, 5)),
pytest.param((3, 2, 5, 3)),
],
)
def test_compile_and_run_transpose_correctness(input_shape, default_compilation_configuration):
"""Test function to make sure compilation and execution of transpose works properly"""
def transpose(x):
return numpy.transpose(x)
compiler_engine = compile_numpy_function(
transpose,
{"x": EncryptedTensor(Integer(64, False), input_shape)},
[numpy.random.randint(0, 120, size=input_shape) for i in range(20)],
default_compilation_configuration,
)
x = numpy.random.randint(0, 120, size=input_shape, dtype=numpy.uint8)
expected = transpose(x)
result = compiler_engine.run(x)
assert (expected == result).all()
# pylint: disable=line-too-long
@pytest.mark.parametrize(
"function,parameters,inputset,error,match",