Files
concrete/compiler/tests/python/test_round_trip.py
2021-08-17 16:53:32 +02:00

52 lines
1.5 KiB
Python

import pytest
from zamalang import compiler
# Well-formed MLIR programs in the HLFHE dialect. `test_valid_mlir_inputs`
# below feeds each one to `compiler.round_trip` and only requires that the
# call does not raise.
VALID_INPUTS = [
# Uses op `HLFHE.add_eint_int` on an `!HLFHE.eint<2>` argument and the
# `i3` constant 1, returning an `!HLFHE.eint<2>`.
"""
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i3
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
""",
# Uses op `HLFHE.apply_lookup_table` with an `!HLFHE.eint<2>` argument and a
# `memref<4xi2>` table argument, returning an `!HLFHE.eint<2>`.
"""
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, memref<4xi2>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}
""",
# Uses op `HLFHE.dot_eint_int` (presumably a dot product) on a
# `tensor<2x!HLFHE.eint<2>>` and a `tensor<2xi3>`, returning a single
# `!HLFHE.eint<2>`. The function signature and the op's type annotation are
# spread over several lines.
"""
func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
%arg1: tensor<2xi3>) -> !HLFHE.eint<2>
{
%1 = "HLFHE.dot_eint_int"(%arg0, %arg1) :
(tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
return %1 : !HLFHE.eint<2>
}
""",
]
# Inputs that `test_invalid_mlir_inputs` below expects `compiler.round_trip`
# to reject with a RuntimeError whose message matches "mlir parsing failed".
INVALID_INPUTS = [
# Plain text that is not MLIR at all.
"nothing really mlir",
# A function whose argument type is `!HLFHE.eint<0>`. The parameter 0 is
# presumably an invalid width for the type; TODO confirm which check rejects it.
"""
func @test(%arg0: !HLFHE.eint<0>) {
return
}
""",
]
@pytest.mark.parametrize("mlir_input", VALID_INPUTS)
def test_valid_mlir_inputs(mlir_input):
    """Round-tripping a well-formed HLFHE program must not raise.

    The round-tripped output is deliberately not inspected: parsing
    correctness is covered by other tests, so this only checks that valid
    input is accepted without an error.
    """
    compiler.round_trip(mlir_input)
@pytest.mark.parametrize("mlir_input", INVALID_INPUTS)
def test_invalid_mlir_inputs(mlir_input):
    """Round-tripping malformed MLIR must fail with a parsing error.

    For every entry of ``INVALID_INPUTS``, ``compiler.round_trip`` is expected
    to raise a ``RuntimeError`` whose message contains (``re.search``) the text
    "mlir parsing failed".
    """
    # Check the message too, not just the exception type, so that an
    # unrelated RuntimeError does not make this test pass by accident.
    with pytest.raises(RuntimeError, match=r"mlir parsing failed"):
        compiler.round_trip(mlir_input)