mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
59 lines
4.3 KiB
Python
59 lines
4.3 KiB
Python
import pytest
|
|
from zamalang import compiler
|
|
|
|
|
|
# MLIR source snippets used to parametrize test_valid_mlir_inputs: each entry
# is passed to compiler.round_trip, which is expected not to raise.
# The entries exercise HLFHE.add_eint_int, HLFHE.apply_lookup_table (with a
# table argument and with a 128-entry constant table) and HLFHE.dot_eint_int.
# NOTE(review): the lone `|` lines and the missing indentation look like
# extraction residue rather than part of the original file — confirm.
VALID_INPUTS = [
|
|
"""
|
|
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
|
%0 = constant 1 : i3
|
|
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>)
|
|
return %1: !HLFHE.eint<2>
|
|
}
|
|
""",
|
|
"""
|
|
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> {
|
|
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>)
|
|
return %1: !HLFHE.eint<2>
|
|
}
|
|
""",
|
|
"""
|
|
func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>,
|
|
%arg1: tensor<2xi3>) -> !HLFHE.eint<2>
|
|
{
|
|
%1 = "HLFHE.dot_eint_int"(%arg0, %arg1) :
|
|
(tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) -> !HLFHE.eint<2>
|
|
return %1 : !HLFHE.eint<2>
|
|
}
|
|
""",
|
|
"""
|
|
func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
|
%tlu = std.constant dense<[0, 36028797018963968, 72057594037927936, 108086391056891904, 144115188075855872, 180143985094819840, 216172782113783808, 252201579132747776, 288230376151711744, 324259173170675712, 360287970189639680, 396316767208603648, 432345564227567616, 468374361246531584, 504403158265495552, 540431955284459520, 576460752303423488, 612489549322387456, 648518346341351424, 684547143360315392, 720575940379279360, 756604737398243328, 792633534417207296, 828662331436171264, 864691128455135232, 900719925474099200, 936748722493063168, 972777519512027136, 1008806316530991104, 1044835113549955072, 1080863910568919040, 1116892707587883008, 1152921504606846976, 1188950301625810944, 1224979098644774912, 1261007895663738880, 1297036692682702848, 1333065489701666816, 1369094286720630784, 1405123083739594752, 1441151880758558720, 1477180677777522688, 1513209474796486656, 1549238271815450624, 1585267068834414592, 1621295865853378560, 1657324662872342528, 1693353459891306496, 1729382256910270464, 1765411053929234432, 1801439850948198400, 1837468647967162368, 1873497444986126336, 1909526242005090304, 1945555039024054272, 1981583836043018240, 2017612633061982208, 2053641430080946176, 2089670227099910144, 2125699024118874112, 2161727821137838080, 2197756618156802048, 2233785415175766016, 2269814212194729984, 2305843009213693952, 2341871806232657920, 2377900603251621888, 2413929400270585856, 2449958197289549824, 2485986994308513792, 2522015791327477760, 2558044588346441728, 2594073385365405696, 2630102182384369664, 2666130979403333632, 2702159776422297600, 2738188573441261568, 2774217370460225536, 2810246167479189504, 2846274964498153472, 2882303761517117440, 2918332558536081408, 2954361355555045376, 2990390152574009344, 3026418949592973312, 3062447746611937280, 3098476543630901248, 3134505340649865216, 3170534137668829184, 3206562934687793152, 3242591731706757120, 3278620528725721088, 3314649325744685056, 3350678122763649024, 3386706919782612992, 3422735716801576960, 
3458764513820540928, 3494793310839504896, 3530822107858468864, 3566850904877432832, 3602879701896396800, 3638908498915360768, 3674937295934324736, 3710966092953288704, 3746994889972252672, 3783023686991216640, 3819052484010180608, 3855081281029144576, 3891110078048108544, 3927138875067072512, 3963167672086036480, 3999196469105000448, 4035225266123964416, 4071254063142928384, 4107282860161892352, 4143311657180856320, 4179340454199820288, 4215369251218784256, 4251398048237748224, 4287426845256712192, 4323455642275676160, 4359484439294640128, 4395513236313604096, 4431542033332568064, 4467570830351532032, 4503599627370496000, 4539628424389459968, 4575657221408423936]> : tensor<128xi64>
|
|
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>)
|
|
return %1: !HLFHE.eint<7>
|
|
}
|
|
""",
|
|
]
|
|
|
|
# Inputs used to parametrize test_invalid_mlir_inputs: for each entry,
# compiler.round_trip is expected to raise a RuntimeError whose message
# matches "mlir parsing failed".
# The first entry is plain text rather than MLIR.
# NOTE(review): the second entry presumably fails because of the zero-width
# `!HLFHE.eint<0>` type — confirm against the dialect's type verifier.
# NOTE(review): the lone `|` lines look like extraction residue rather than
# part of the original file — confirm.
INVALID_INPUTS = [
|
|
"nothing really mlir",
|
|
"""
|
|
func @test(%arg0: !HLFHE.eint<0>) {
|
|
return
|
|
}
|
|
""",
|
|
]
|
|
|
|
|
|
@pytest.mark.parametrize("mlir_input", VALID_INPUTS)
def test_valid_mlir_inputs(mlir_input):
    """Round-trip one valid MLIR snippet and expect no exception.

    The result is deliberately not inspected: parsing correctness is covered
    by other tests, so this only guards against valid inputs raising an error.
    """
    compiler.round_trip(mlir_input)
|
|
|
|
|
|
@pytest.mark.parametrize("mlir_input", INVALID_INPUTS)
def test_invalid_mlir_inputs(mlir_input):
    """Round-trip one invalid input and expect a parsing failure.

    Invalid inputs must raise a RuntimeError whose message matches
    "mlir parsing failed".
    """
    expect_parse_failure = pytest.raises(RuntimeError, match=r"mlir parsing failed")
    with expect_parse_failure:
        compiler.round_trip(mlir_input)
|