diff --git a/compiler/include/zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h b/compiler/include/zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h
new file mode 100644
index 000000000..39f2a10e3
--- /dev/null
+++ b/compiler/include/zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h
@@ -0,0 +1,14 @@
+#ifndef ZAMALANG_DIALECT_HLFHE_TRANSFORMS_TENSOROPSTOLINALG_H
+#define ZAMALANG_DIALECT_HLFHE_TRANSFORMS_TENSOROPSTOLINALG_H
+
+#include <mlir/Pass/Pass.h>
+
+namespace mlir {
+namespace zamalang {
+namespace HLFHE {
+std::unique_ptr<mlir::Pass> createLowerTensorOpsToLinalgPass();
+} // namespace HLFHE
+} // namespace zamalang
+} // namespace mlir
+
+#endif
diff --git a/compiler/lib/Dialect/HLFHE/CMakeLists.txt b/compiler/lib/Dialect/HLFHE/CMakeLists.txt
index f33061b2d..9f57627c3 100644
--- a/compiler/lib/Dialect/HLFHE/CMakeLists.txt
+++ b/compiler/lib/Dialect/HLFHE/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/compiler/lib/Dialect/HLFHE/Transforms/CMakeLists.txt b/compiler/lib/Dialect/HLFHE/Transforms/CMakeLists.txt
new file mode 100644
index 000000000..02235e7e2
--- /dev/null
+++ b/compiler/lib/Dialect/HLFHE/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(HLFHEDialectTransforms
+  TensorOpsToLinalg.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
+
+  DEPENDS
+  HLFHEDialect
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  HLFHEDialect)
+
+target_link_libraries(HLFHEDialect PUBLIC MLIRIR)
diff --git a/compiler/lib/Dialect/HLFHE/Transforms/TensorOpsToLinalg.cpp b/compiler/lib/Dialect/HLFHE/Transforms/TensorOpsToLinalg.cpp
new file mode 100644
index 000000000..c304d19cd
--- /dev/null
+++ b/compiler/lib/Dialect/HLFHE/Transforms/TensorOpsToLinalg.cpp
@@ -0,0 +1,128 @@
+#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/OperationSupport.h"
+#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
+#include "llvm/ADT/SmallVector.h"
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/DialectConversion.h>
+#include <memory>
+
+struct DotToLinalgGeneric : public ::mlir::RewritePattern {
+  DotToLinalgGeneric(::mlir::MLIRContext *context)
+      : ::mlir::RewritePattern("HLFHE.dot_eint_int", 1, context,
+                               {"linalg.generic"}) {}
+
+  // This rewrite pattern transforms any instance of
+  // `HLFHE.dot_eint_int` to an instance of `linalg.generic` with an
+  // appropriate region using `HLFHE.mul_eint_int` and
+  // `HLFHE.add_eint` operations and an appropriate specification for
+  // the iteration dimensions.
+  //
+  // Example:
+  //
+  //   "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
+  //     (memref<?x!HLFHE.eint<0>>,
+  //      memref<?xi32>,
+  //      memref<!HLFHE.eint<0>>) -> ()
+  //
+  // becomes:
+  //
+  //   linalg.generic {
+  //     indexing_maps = [affine_map<(d0) -> (d0)>,
+  //                      affine_map<(d0) -> (d0)>,
+  //                      affine_map<(d0) -> ()>],
+  //     iterator_types = ["reduction"]
+  //   } ins(%arg0, %arg1 : memref<?x!HLFHE.eint<0>>, memref<?xi32>)
+  //     outs(%arg2 : memref<!HLFHE.eint<0>>)
+  //   {
+  //     ^bb0(%arg3: !HLFHE.eint<0>, %arg4: i32, %arg5: !HLFHE.eint<0>):
+  //       %0 = "HLFHE.mul_eint_int"(%arg3, %arg4)
+  //              : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
+  //       %1 = "HLFHE.add_eint"(%0, %arg5)
+  //              : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
+  //       linalg.yield %1 : !HLFHE.eint<0>
+  //   }
+  //
+  ::mlir::LogicalResult
+  matchAndRewrite(::mlir::Operation *op0,
+                  ::mlir::PatternRewriter &rewriter) const override {
+    ::mlir::zamalang::HLFHE::Dot &&dotOp =
+        ::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0);
+
+    mlir::TypeRange resTypes{};
+    llvm::SmallVector<mlir::Value, 2> ins{dotOp.lhs(), dotOp.rhs()};
+    llvm::SmallVector<mlir::Value, 1> outs{dotOp.out()};
+
+    llvm::SmallVector<mlir::AffineMap, 3> maps{
+        mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
+        mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
+        mlir::AffineMap::get(1, 0, this->getContext())};
+
+    llvm::SmallVector<llvm::StringRef, 1> itTypes{"reduction"};
+    llvm::StringRef doc{""};
+    llvm::StringRef call{""};
+
+    auto regBuilder = [&](mlir::OpBuilder &nestedBuilder,
+                          mlir::Location nestedLoc,
+                          mlir::ValueRange blockArgs) {
+      mlir::zamalang::HLFHE::MulEintIntOp mul =
+          nestedBuilder.create<mlir::zamalang::HLFHE::MulEintIntOp>(
+              dotOp.getLoc(), blockArgs[0], blockArgs[1]);
+      mlir::zamalang::HLFHE::AddEintOp add =
+          nestedBuilder.create<mlir::zamalang::HLFHE::AddEintOp>(
+              dotOp.getLoc(), mul, blockArgs[2]);
+
+      nestedBuilder.create<mlir::linalg::YieldOp>(dotOp.getLoc(),
+                                                  add.getResult());
+    };
+
+    mlir::linalg::GenericOp gop = rewriter.create<mlir::linalg::GenericOp>(
+        dotOp.getLoc(), resTypes, ins, outs, maps, itTypes, doc, call,
+        regBuilder);
+
+    rewriter.replaceOp(op0, {gop.getODSResults(0)});
+
+    return ::mlir::success();
+  };
+};
+
+namespace {
+struct LowerTensorOpsToLinalgPass
+    : public mlir::PassWrapper<LowerTensorOpsToLinalgPass,
+                               mlir::FunctionPass> {
+  void getDependentDialects(mlir::DialectRegistry &registry) const override {
+    registry.insert<mlir::linalg::LinalgDialect>();
+  }
+  void runOnFunction() final;
+};
+
+void LowerTensorOpsToLinalgPass::runOnFunction() {
+  mlir::FuncOp function = this->getFunction();
+
+  mlir::ConversionTarget target(getContext());
+
+  target.addLegalDialect<mlir::linalg::LinalgDialect>();
+  target.addLegalDialect<mlir::StandardOpsDialect>();
+  target.addLegalDialect<mlir::AffineDialect>();
+  target.addLegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
+  target.addIllegalOp<mlir::zamalang::HLFHE::Dot>();
+
+  mlir::OwningRewritePatternList patterns(&getContext());
+  patterns.insert<DotToLinalgGeneric>(&getContext());
+
+  if (mlir::applyPartialConversion(function, target, std::move(patterns))
+          .failed())
+    this->signalPassFailure();
+}
+
+} // namespace
+
+namespace mlir {
+namespace zamalang {
+namespace HLFHE {
+std::unique_ptr<mlir::Pass> createLowerTensorOpsToLinalgPass() {
+  return std::make_unique<LowerTensorOpsToLinalgPass>();
+}
+} // namespace HLFHE
+} // namespace zamalang
+} // namespace mlir
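
Usage sketch (not part of the patch itself): a minimal example of how the new entry point createLowerTensorOpsToLinalgPass could be scheduled on the functions of a module, assuming an already-constructed mlir::MLIRContext and a parsed mlir::ModuleOp. The driver function lowerHLFHEDotOps below is hypothetical and only illustrates the API introduced by this change.

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"

    // Hypothetical driver: run the HLFHE.dot_eint_int -> linalg.generic
    // lowering over every function of the module.
    mlir::LogicalResult lowerHLFHEDotOps(mlir::MLIRContext &context,
                                         mlir::ModuleOp module) {
      mlir::PassManager pm(&context);
      // The pass is a FunctionPass, so nest it under the module's FuncOps.
      pm.addNestedPass<mlir::FuncOp>(
          mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass());
      return pm.run(module);
    }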