diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 941766c05..4fe476ed0 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -2021,6 +2021,33 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura ) +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((4,)), + pytest.param((3, 2)), + pytest.param((3, 2, 5)), + pytest.param((3, 2, 5, 3)), + ], +) +def test_compile_and_run_transpose_correctness(input_shape, default_compilation_configuration): + """Test function to make sure compilation and execution of transpose works properly""" + + def transpose(x): + return numpy.transpose(x) + + compiler_engine = compile_numpy_function( + transpose, + {"x": EncryptedTensor(Integer(64, False), input_shape)}, + [numpy.random.randint(0, 120, size=input_shape) for i in range(20)], + default_compilation_configuration, + ) + x = numpy.random.randint(0, 120, size=input_shape, dtype=numpy.uint8) + expected = transpose(x) + result = compiler_engine.run(x) + assert (expected == result).all() + + # pylint: disable=line-too-long @pytest.mark.parametrize( "function,parameters,inputset,error,match",