import pytest
from zamalang import compiler

# MLIR snippets that `compiler.round_trip` is expected to accept without raising.
VALID_INPUTS = [
    # HLFHE.add_eint_int: an eint<2> argument plus an i3 constant.
    """ func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i3 %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> } """,
    # HLFHE.apply_lookup_table: an eint<2> argument with a tensor<4xi2> table argument.
    """ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> { %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> } """,
    # HLFHE.dot_eint_int: tensor<2x!HLFHE.eint<2>> with tensor<2xi3>.
    """ func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { %1 = "HLFHE.dot_eint_int"(%arg0, %arg1) : (tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2> return %1 : !HLFHE.eint<2> } """,
    # HLFHE.apply_lookup_table on eint<7> with a constant table declared as tensor<128xi64>.
    """ func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %tlu = std.constant dense<[0, 36028797018963968, 72057594037927936, 108086391056891904, 144115188075855872, 180143985094819840, 216172782113783808, 252201579132747776, 288230376151711744, 324259173170675712, 360287970189639680, 396316767208603648, 432345564227567616, 468374361246531584, 504403158265495552, 540431955284459520, 576460752303423488, 612489549322387456, 648518346341351424, 684547143360315392, 720575940379279360, 756604737398243328, 792633534417207296, 828662331436171264, 864691128455135232, 900719925474099200, 936748722493063168, 972777519512027136, 1008806316530991104, 1044835113549955072, 1080863910568919040, 1116892707587883008, 1152921504606846976, 1188950301625810944, 1224979098644774912, 1261007895663738880, 1297036692682702848, 1333065489701666816, 1369094286720630784, 1405123083739594752, 1441151880758558720, 1477180677777522688, 1513209474796486656, 1549238271815450624, 1585267068834414592, 1621295865853378560, 1657324662872342528, 1693353459891306496, 1729382256910270464, 1765411053929234432, 1801439850948198400, 1837468647967162368, 1873497444986126336, 1909526242005090304, 1945555039024054272, 1981583836043018240, 2017612633061982208, 2053641430080946176, 2089670227099910144, 
2125699024118874112, 2161727821137838080, 2197756618156802048, 2233785415175766016, 2269814212194729984, 2305843009213693952, 2341871806232657920, 2377900603251621888, 2413929400270585856, 2449958197289549824, 2485986994308513792, 2522015791327477760, 2558044588346441728, 2594073385365405696, 2630102182384369664, 2666130979403333632, 2702159776422297600, 2738188573441261568, 2774217370460225536, 2810246167479189504, 2846274964498153472, 2882303761517117440, 2918332558536081408, 2954361355555045376, 2990390152574009344, 3026418949592973312, 3062447746611937280, 3098476543630901248, 3134505340649865216, 3170534137668829184, 3206562934687793152, 3242591731706757120, 3278620528725721088, 3314649325744685056, 3350678122763649024, 3386706919782612992, 3422735716801576960, 3458764513820540928, 3494793310839504896, 3530822107858468864, 3566850904877432832, 3602879701896396800, 3638908498915360768, 3674937295934324736, 3710966092953288704, 3746994889972252672, 3783023686991216640, 3819052484010180608, 3855081281029144576, 3891110078048108544, 3927138875067072512, 3963167672086036480, 3999196469105000448, 4035225266123964416, 4071254063142928384, 4107282860161892352, 4143311657180856320, 4179340454199820288, 4215369251218784256, 4251398048237748224, 4287426845256712192, 4323455642275676160, 4359484439294640128, 4395513236313604096, 4431542033332568064, 4467570830351532032, 4503599627370496000, 4539628424389459968, 4575657221408423936]> : tensor<128xi64> %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> } """,
]

# Inputs used as negative cases for `compiler.round_trip`.
INVALID_INPUTS = [
    # Plain text that is not MLIR at all.
    "nothing really mlir",
    # NOTE(review): presumably rejected because of the zero-width `!HLFHE.eint<0>`
    # type -- confirm against the dialect's verifier.
    """ func @test(%arg0: !HLFHE.eint<0>) { return } """,
]


@pytest.mark.parametrize("mlir_input", VALID_INPUTS)
def test_valid_mlir_inputs(mlir_input):
    """Round-trip each entry of VALID_INPUTS; the test passes if nothing is raised."""
    # The result is deliberately not inspected: per the original author's note,
    # parsing correctness is already covered by other tests. Here we only make
    # sure that valid inputs do not raise an error.
    compiler.round_trip(mlir_input)
@pytest.mark.parametrize("mlir_input", INVALID_INPUTS)
def test_invalid_mlir_inputs(mlir_input):
    """Round-tripping an invalid input must raise a RuntimeError.

    The error message is required to match the pattern "mlir parsing failed".
    """
    expected_message = r"mlir parsing failed"
    with pytest.raises(RuntimeError, match=expected_message):
        compiler.round_trip(mlir_input)