feat(python): support functions returning tensors

This commit is contained in:
youben11
2021-11-04 19:06:11 +01:00
committed by Ayoub Benaissa
parent badc8e44bf
commit b501e3d6c0
5 changed files with 133 additions and 9 deletions

View File

@@ -29,12 +29,28 @@ from zamalang import CompilerEngine
20,
id="dot_eint_int"
),
pytest.param(
"""
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
return %res : tensor<4x!HLFHE.eint<6>>
}
""",
([31, 6, 12, 9], [32, 9, 2, 3]),
[63, 15, 14, 12],
id="add_eint_int_1D"
),
],
)
def test_compile_and_run(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
assert engine.run(*args) == expected_result
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result
else:
# numpy array on the left
assert (engine.run(*args) == expected_result).all()