mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(python): support functions returning tensors
This commit is contained in:
@@ -29,12 +29,28 @@ from zamalang import CompilerEngine
|
||||
20,
|
||||
id="dot_eint_int"
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
|
||||
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
|
||||
return %res : tensor<4x!HLFHE.eint<6>>
|
||||
}
|
||||
""",
|
||||
([31, 6, 12, 9], [32, 9, 2, 3]),
|
||||
[63, 15, 14, 12],
|
||||
id="add_eint_int_1D"
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
assert engine.run(*args) == expected_result
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
else:
|
||||
# numpy array on the left
|
||||
assert (engine.run(*args) == expected_result).all()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user