diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 99849468c..a838a2880 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include +#include #include #include #include @@ -74,6 +75,10 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), enablePass); addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(), enablePass); + addFilteredPassToPassManager(pm, mlir::createLinalgBufferizePass(), + enablePass); + addFilteredPassToPassManager(pm, mlir::createConvertLinalgToLoopsPass(), + enablePass); addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), enablePass); addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(), enablePass); diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index 429b346d5..1334ceab3 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -344,4 +344,41 @@ func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> { ASSERT_EQ(t_res[0], in[0] + in[0]); ASSERT_EQ(t_res[1], in[0] + in[1]); ASSERT_EQ(t_res[2], in[1] + in[1]); +} + +TEST(CompileAndRunTensorEncrypted, linalg_generic) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +#map0 = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0) -> (0)> +func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + %tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>> + %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) { + ^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors + %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7> + %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7> + linalg.yield %5 : !HLFHE.eint<7> + } -> tensor<1x!HLFHE.eint<7>> + %c0 = constant 0 : index + %ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>> + return %ret : !HLFHE.eint<7> +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set arg0, arg1, acc + const size_t in_size = 2; + uint8_t arg0[in_size] = {2, 8}; + ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size)); + uint8_t arg1[in_size] = {6, 8}; + ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size)); + ASSERT_LLVM_ERROR(argument->setArg(2, 0)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, 76); } \ No newline at end of file