From bc975d904ea073c3425fa6e719c1ffdf53632e2c Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Wed, 18 Aug 2021 12:15:55 +0200 Subject: [PATCH] feat(compiler): introduce bufferization passes in lowering pipeline to llvm --- .../LowLFHEToConcreteCAPI.cpp | 2 +- compiler/lib/Support/CompilerTools.cpp | 14 ++++++++++++- compiler/tests/RunJit/tensor_cst.mlir | 21 +++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 compiler/tests/RunJit/tensor_cst.mlir diff --git a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp index a3483587f..47d2e341d 100644 --- a/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp +++ b/compiler/lib/Conversion/LowLFHEToConcreteCAPI/LowLFHEToConcreteCAPI.cpp @@ -220,7 +220,7 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() { // Apply the conversion mlir::ModuleOp op = getOperation(); - if (mlir::applyFullConversion(op, target, std::move(patterns)).failed()) { + if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); } } diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 03f1571ad..829945d95 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -1,6 +1,9 @@ +#include "mlir/Dialect/Tensor/Transforms/Passes.h" #include +#include #include #include +#include #include "zamalang/Conversion/Passes.h" #include "zamalang/Support/CompilerTools.h" @@ -63,8 +66,17 @@ mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( mlir::MLIRContext &context, mlir::Operation *module, llvm::function_ref enablePass) { - mlir::PassManager pm(&context); + + // Bufferize + addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(), + enablePass); + addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), enablePass); + addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(), + enablePass); + addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), enablePass); + addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(), + enablePass); addFilteredPassToPassManager( pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), enablePass); diff --git a/compiler/tests/RunJit/tensor_cst.mlir b/compiler/tests/RunJit/tensor_cst.mlir new file mode 100644 index 000000000..8405ce137 --- /dev/null +++ b/compiler/tests/RunJit/tensor_cst.mlir @@ -0,0 +1,21 @@ +// RUN: zamacompiler %s --run-jit --jit-args 11 2>&1| FileCheck %s + +// CHECK-LABEL: 116 +func @main(%arg0: index) -> i7 { + %t = std.constant dense<[127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84, 83, 82, 81, 80, 79, 78, 77, 76, 75, 74, 73, 72, 71, 70, 69, 68, 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]> : tensor<128xi7> + %c = tensor.extract %t[%arg0] : tensor<128xi7> + return %c : i7 +} + +// // ----- + +// func @extract(%arg0: index, %t: tensor<128xi7>) -> i7{ +// %c = tensor.extract %t[%arg0] : tensor<128xi7> +// return %c : i7 +// } + +// func @main(%arg0: index) -> i7 { +// %t = std.constant dense<[127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84, 83, 82, 81, 80, 79, 78, 77, 76, 75, 74, 73, 72, 71, 70, 69, 68, 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]> : tensor<128xi7> +// %c = call @extract(%arg0, %t): (index, tensor<128xi7>) -> i7 +// return %c : i7 +// } \ No newline at end of file