feat(compiler/llvm-pipeline): Bufferize linalg and convert to loops

This commit is contained in:
Quentin Bourgerie
2021-08-25 14:10:18 +02:00
parent a654fb2d0e
commit de7129fe8e
2 changed files with 42 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include <llvm/Support/TargetSelect.h>
#include <mlir/Dialect/Linalg/Passes.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <mlir/Target/LLVMIR/Export.h>
@@ -74,6 +75,10 @@ mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), enablePass);
addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(),
enablePass);
addFilteredPassToPassManager(pm, mlir::createLinalgBufferizePass(),
enablePass);
addFilteredPassToPassManager(pm, mlir::createConvertLinalgToLoopsPass(),
enablePass);
addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), enablePass);
addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(),
enablePass);

View File

@@ -344,4 +344,41 @@ func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
ASSERT_EQ(t_res[0], in[0] + in[0]);
ASSERT_EQ(t_res[1], in[0] + in[1]);
ASSERT_EQ(t_res[2], in[1] + in[1]);
}
TEST(CompileAndRunTensorEncrypted, linalg_generic) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) {
^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7>
%5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7>
linalg.yield %5 : !HLFHE.eint<7>
} -> tensor<1x!HLFHE.eint<7>>
%c0 = constant 0 : index
%ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 2;
uint8_t arg0[in_size] = {2, 8};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {6, 8};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
ASSERT_LLVM_ERROR(argument->setArg(2, 0));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 76);
}