mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
test: test transpose correctness
This commit is contained in:
@@ -2021,6 +2021,33 @@ def test_compile_function_with_direct_tlu_overflow(default_compilation_configura
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape",
|
||||
[
|
||||
pytest.param((4,)),
|
||||
pytest.param((3, 2)),
|
||||
pytest.param((3, 2, 5)),
|
||||
pytest.param((3, 2, 5, 3)),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run_transpose_correctness(input_shape, default_compilation_configuration):
|
||||
"""Test function to make sure compilation and execution of transpose works properly"""
|
||||
|
||||
def transpose(x):
|
||||
return numpy.transpose(x)
|
||||
|
||||
compiler_engine = compile_numpy_function(
|
||||
transpose,
|
||||
{"x": EncryptedTensor(Integer(64, False), input_shape)},
|
||||
[numpy.random.randint(0, 120, size=input_shape) for i in range(20)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
x = numpy.random.randint(0, 120, size=input_shape, dtype=numpy.uint8)
|
||||
expected = transpose(x)
|
||||
result = compiler_engine.run(x)
|
||||
assert (expected == result).all()
|
||||
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
@pytest.mark.parametrize(
|
||||
"function,parameters,inputset,error,match",
|
||||
|
||||
Reference in New Issue
Block a user