#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" #include #include "zamalang/Conversion/Passes.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" #include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h" #include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h" struct DotToLinalgGeneric : public ::mlir::OpRewritePattern { DotToLinalgGeneric(::mlir::MLIRContext *context) : ::mlir::OpRewritePattern<::mlir::zamalang::HLFHELinalg::Dot>(context, 1) {} // This rewrite pattern transforms any instance of // `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an // appropriate region using `HLFHE.mul_eint_int` and // `HLFHE.add_eint` operations, an appropriate specification for the // iteration dimensions and appropriate operations managing the // accumulator of `linalg.generic`. // // Example: // // %o = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) : // (tensor<4x!HLFHE.eint<0>>, // tensor<4xi32>) -> (!HLFHE.eint<0>) // // becomes: // // %0 = "HLFHE.zero"() : () -> !HLFHE.eint<0> // %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<0>> // %2 = linalg.generic { // indexing_maps = [#map0, #map0, #map1], // iterator_types = ["reduction"] // } // ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<0>>, tensor<2xi32>) // outs(%1 : tensor<1x!HLFHE.eint<0>>) { // ^bb0(%arg2: !HLFHE.eint<0>, %arg3: i32, %arg4: !HLFHE.eint<0>): // %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : // (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0> // // %5 = "HLFHE.add_eint"(%4, %arg4) : // (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0> // // linalg.yield %5 : !HLFHE.eint<0> // } -> tensor<1x!HLFHE.eint<0>> // // %c0 = constant 0 : index // %o = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<0>> // ::mlir::LogicalResult matchAndRewrite(::mlir::zamalang::HLFHELinalg::Dot dotOp, ::mlir::PatternRewriter &rewriter) const override { // Zero value to initialize accumulator mlir::Value zeroCst = rewriter.create( dotOp.getLoc(), dotOp.lhs().getType().cast().getElementType()); // Create one-dimensional accumulator with a single element // (`tensor.from_elements` does not allow for the creation of 0d // tensors) mlir::tensor::FromElementsOp feOp = rewriter.create(dotOp.getLoc(), zeroCst); mlir::Value accu = feOp.getResult(); // Create `linalg.generic` op llvm::SmallVector resTypes{accu.getType()}; llvm::SmallVector ins{dotOp.lhs(), dotOp.rhs()}; llvm::SmallVector outs{accu}; llvm::SmallVector maps{ mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()), mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()), mlir::AffineMap::get(1, 0, {rewriter.getAffineConstantExpr(0)}, this->getContext())}; llvm::SmallVector itTypes{"reduction"}; llvm::StringRef doc{""}; llvm::StringRef call{""}; auto regBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { mlir::zamalang::HLFHE::MulEintIntOp mul = nestedBuilder.create( dotOp.getLoc(), blockArgs[0], blockArgs[1]); mlir::zamalang::HLFHE::AddEintOp add = nestedBuilder.create( dotOp.getLoc(), mul, blockArgs[2]); nestedBuilder.create(dotOp.getLoc(), add.getResult()); }; mlir::linalg::GenericOp gop = rewriter.create( dotOp.getLoc(), resTypes, ins, outs, 
mlir::AffineMap
getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
                        const mlir::RankedTensorType &operandType,
                        ::mlir::PatternRewriter &rewriter) {
  mlir::SmallVector<mlir::AffineExpr, 4> affineExprs;
  auto resultShape = resultType.getShape();
  auto operandShape = operandType.getShape();
  affineExprs.reserve(resultShape.size());
  size_t deltaNumDim = resultShape.size() - operandShape.size();
  for (size_t i = 0; i < operandShape.size(); i++) {
    if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) {
      affineExprs.push_back(rewriter.getAffineConstantExpr(0));
    } else {
      affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));
    }
  }
  return mlir::AffineMap::get(resultShape.size(), 0, affineExprs,
                              rewriter.getContext());
}

// This template rewrite pattern transforms any instance of a binary
// `HLFHELinalgOp` operator that implements the broadcasting rules into an
// instance of `linalg.generic` with an appropriate region using the
// `HLFHEOp` operation, an appropriate specification for the iteration
// dimensions and appropriate operations managing the accumulator of
// `linalg.generic`.
//
// Example:
//
//   %res = HLFHELinalg.op(%lhs, %rhs) :
//     (tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1xT>)
//       -> tensor<D$R"x...xD1x!HLFHE.eint<p>>
//
// becomes:
//
//   #maps_0 = [
//     affine_map<(a$R", ..., a$A, ..., a1) ->
//       (dim(lhs, $A) == 1 ? 0 : a$A, ..., dim(lhs, 1) == 1 ? 0 : a1)>,
//     affine_map<(a$R", ..., a1) ->
//       (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>,
//     affine_map<(a$R", ..., a1) -> (a$R", ..., a1)>
//   ]
//   #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel", ..., "parallel"] // $R" parallel
//   }
//   %init = linalg.init_tensor [D$R", ..., D1]
//     : tensor<D$R"x...xD1x!HLFHE.eint<p>>
//   %res = linalg.generic {
//     ins(%lhs, %rhs : tensor<D$Ax...xD1x!HLFHE.eint<p>>,
//                      tensor<D$B'x...xD1xT>)
//     outs(%init : tensor<D$R"x...xD1x!HLFHE.eint<p>>)
//     {
//       ^bb0(%arg0: !HLFHE.eint<p>, %arg1: T):
//         %0 = HLFHE.op(%arg0, %arg1) : (!HLFHE.eint<p>, T) -> !HLFHE.eint<p>
//         linalg.yield %0 : !HLFHE.eint<p>
//     }
//   }
//
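// As a concrete illustration of the indexing maps (shapes and types chosen
// arbitrarily), for
//
//   (tensor<2x1x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<2x3x!HLFHE.eint<4>>
//
// the three maps built with getBroadcastedAffineMap above are
//
//   affine_map<(d0, d1) -> (d0, 0)>   // lhs: its size-1 dimension is broadcast
//   affine_map<(d0, d1) -> (d1)>      // rhs: its missing leading dimension is broadcast
//   affine_map<(d0, d1) -> (d0, d1)>  // result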

template <typename HLFHELinalgOp, typename HLFHEOp>
struct HLFHELinalgOpToLinalgGeneric
    : public mlir::OpRewritePattern<HLFHELinalgOp> {
  HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
                               mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(HLFHELinalgOp linalgOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType resultTy =
        ((mlir::Type)linalgOp->getResult(0).getType())
            .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType lhsTy =
        ((mlir::Type)linalgOp.lhs().getType()).cast<mlir::RankedTensorType>();
    mlir::RankedTensorType rhsTy =
        ((mlir::Type)linalgOp.rhs().getType()).cast<mlir::RankedTensorType>();

    // linalg.init_tensor for the initial value
    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
        linalgOp.getLoc(), resultTy.getShape(), resultTy.getElementType());

    // Create the affine maps (#maps_0 in the example above)
    llvm::SmallVector<mlir::AffineMap, 3> maps{
        getBroadcastedAffineMap(resultTy, lhsTy, rewriter),
        getBroadcastedAffineMap(resultTy, rhsTy, rewriter),
        getBroadcastedAffineMap(resultTy, resultTy, rewriter),
    };

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef, 3> iteratorTypes(
        resultTy.getShape().size(), "parallel");

    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      HLFHEOp hlfheOp = nestedBuilder.create<HLFHEOp>(
          linalgOp.getLoc(), blockArgs[0], blockArgs[1]);

      nestedBuilder.create<mlir::linalg::YieldOp>(linalgOp.getLoc(),
                                                  hlfheOp.getResult());
    };

    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value, 2> ins{linalgOp.lhs(), linalgOp.rhs()};
    llvm::SmallVector<mlir::Value, 1> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(linalgOp.getLoc(), resTypes,
                                                 ins, outs, maps,
                                                 iteratorTypes, doc, call,
                                                 bodyBuilder);

    rewriter.replaceOp(linalgOp, {genericOp.getResult(0)});

    return ::mlir::success();
  };
};

// This rewrite pattern transforms any instance of the
// `HLFHELinalg.apply_lookup_table` operator into an instance of
// `linalg.generic` with an appropriate region using the
// `HLFHE.apply_lookup_table` operation, an appropriate specification for the
// iteration dimensions and appropriate operations managing the accumulator
// of `linalg.generic`.
//
// Example:
//
//   HLFHELinalg.apply_lookup_table(%t, %lut) :
//     (tensor<DNx...xD1x!HLFHE.eint<p>>, tensor<4xi64>)
//       -> tensor<DNx...xD1x!HLFHE.eint<p>>
//
// becomes:
//
//   #maps_0 = [
//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>
//   ]
//   #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel", ..., "parallel"] // N parallel
//   }
//   %init = linalg.init_tensor [DN, ..., D1]
//     : tensor<DNx...xD1x!HLFHE.eint<p>>
//   %res = linalg.generic {
//     ins(%t : tensor<DNx...xD1x!HLFHE.eint<p>>)
//     outs(%init : tensor<DNx...xD1x!HLFHE.eint<p>>)
//     {
//       ^bb0(%arg0: !HLFHE.eint<p>):
//         %0 = HLFHE.apply_lookup_table(%arg0, %lut) :
//           (!HLFHE.eint<p>, tensor<4xi64>) -> !HLFHE.eint<p>
//         linalg.yield %0 : !HLFHE.eint<p>
//     }
//   }
//
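// Note that the lookup table %lut is not part of the `ins` of the
// `linalg.generic` op: the region builder below passes `lutOp.lut()` directly
// to the scalar `HLFHE.apply_lookup_table` operation, so the generic op only
// iterates over the input tensor %t.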

struct HLFHELinalgApplyLookupTableToLinalgGeneric
    : public mlir::OpRewritePattern<
          mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp> {
  HLFHELinalgApplyLookupTableToLinalgGeneric(::mlir::MLIRContext *context,
                                             mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<
            mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(context,
                                                                 benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp lutOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType resultTy =
        ((mlir::Type)lutOp->getResult(0).getType())
            .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType tTy =
        ((mlir::Type)lutOp.t().getType()).cast<mlir::RankedTensorType>();

    // linalg.init_tensor for the initial value
    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
        lutOp.getLoc(), resultTy.getShape(), resultTy.getElementType());

    // Create the affine maps (identity maps on the operand and the result)
    llvm::SmallVector<mlir::AffineMap, 2> maps{
        mlir::AffineMap::getMultiDimIdentityMap(tTy.getShape().size(),
                                                this->getContext()),
        mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
                                                this->getContext()),
    };

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef, 3> iteratorTypes(
        resultTy.getShape().size(), "parallel");

    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      mlir::zamalang::HLFHE::ApplyLookupTableEintOp hlfheOp =
          nestedBuilder.create<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(
              lutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
              lutOp.lut());

      nestedBuilder.create<mlir::linalg::YieldOp>(lutOp.getLoc(),
                                                  hlfheOp.getResult());
    };

    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value, 1> ins{lutOp.t()};
    llvm::SmallVector<mlir::Value, 1> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(lutOp.getLoc(), resTypes,
                                                 ins, outs, maps,
                                                 iteratorTypes, doc, call,
                                                 bodyBuilder);

    rewriter.replaceOp(lutOp, {genericOp.getResult(0)});

    return ::mlir::success();
  };
};

// This rewrite pattern transforms any instance of the `HLFHELinalg.neg_eint`
// operator into an instance of `linalg.generic` with an appropriate region
// using the `HLFHE.neg_eint` operation, an appropriate specification for the
// iteration dimensions and appropriate operations managing the accumulator
// of `linalg.generic`.
//
// Example:
//
//   HLFHELinalg.neg_eint(%tensor) :
//     tensor<DNx...xD1x!HLFHE.eint<p>> -> tensor<DNx...xD1x!HLFHE.eint<p>>
//
// becomes:
//
//   #maps_0 = [
//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>
//   ]
//   #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel", ..., "parallel"] // N parallel
//   }
//   %init = linalg.init_tensor [DN, ..., D1]
//     : tensor<DNx...xD1x!HLFHE.eint<p>>
//   %res = linalg.generic {
//     ins(%tensor : tensor<DNx...xD1x!HLFHE.eint<p>>)
//     outs(%init : tensor<DNx...xD1x!HLFHE.eint<p>>)
//     {
//       ^bb0(%arg0: !HLFHE.eint<p>):
//         %0 = HLFHE.neg_eint(%arg0) : !HLFHE.eint<p> -> !HLFHE.eint<p>
//         linalg.yield %0 : !HLFHE.eint<p>
//     }
//   }
//
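// As with `HLFHE.apply_lookup_table` above, the scalar `HLFHE.neg_eint`
// operation is built with the result element type passed explicitly
// (`resultTy.getElementType()`) in addition to its operand.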

struct HLFHELinalgNegEintToLinalgGeneric
    : public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp> {
  HLFHELinalgNegEintToLinalgGeneric(::mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp>(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::zamalang::HLFHELinalg::NegEintOp negEintOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::RankedTensorType resultTy =
        ((mlir::Type)negEintOp->getResult(0).getType())
            .cast<mlir::RankedTensorType>();
    mlir::RankedTensorType tensorTy =
        ((mlir::Type)negEintOp.tensor().getType())
            .cast<mlir::RankedTensorType>();

    // linalg.init_tensor for the initial value
    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
        negEintOp.getLoc(), resultTy.getShape(), resultTy.getElementType());

    // Create the affine maps (identity maps on the operand and the result)
    llvm::SmallVector<mlir::AffineMap, 2> maps{
        mlir::AffineMap::getMultiDimIdentityMap(tensorTy.getShape().size(),
                                                this->getContext()),
        mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
                                                this->getContext()),
    };

    // Create the iterator_types
    llvm::SmallVector<llvm::StringRef, 3> iteratorTypes(
        resultTy.getShape().size(), "parallel");

    // Create the body of the `linalg.generic` op
    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      mlir::zamalang::HLFHE::NegEintOp hlfheOp =
          nestedBuilder.create<mlir::zamalang::HLFHE::NegEintOp>(
              negEintOp.getLoc(), resultTy.getElementType(), blockArgs[0]);

      nestedBuilder.create<mlir::linalg::YieldOp>(negEintOp.getLoc(),
                                                  hlfheOp.getResult());
    };

    // Create the `linalg.generic` op
    llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
    llvm::SmallVector<mlir::Value, 1> ins{negEintOp.tensor()};
    llvm::SmallVector<mlir::Value, 1> outs{init};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};

    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(negEintOp.getLoc(), resTypes,
                                                 ins, outs, maps,
                                                 iteratorTypes, doc, call,
                                                 bodyBuilder);

    rewriter.replaceOp(negEintOp, {genericOp.getResult(0)});

    return ::mlir::success();
  };
};

namespace {
struct HLFHETensorOpsToLinalg
    : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {

  void runOnFunction() final;
};

void HLFHETensorOpsToLinalg::runOnFunction() {
  mlir::FuncOp function = this->getFunction();

  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::linalg::LinalgDialect>();
  target.addLegalDialect<mlir::memref::MemRefDialect>();
  target.addLegalDialect<mlir::StandardOpsDialect>();
  target.addLegalDialect<mlir::tensor::TensorDialect>();
  target.addLegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
  target.addIllegalOp<mlir::zamalang::HLFHELinalg::Dot>();
  target.addIllegalDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();

  mlir::OwningRewritePatternList patterns(&getContext());
  patterns.insert<DotToLinalgGeneric>(&getContext());
  patterns.insert<
      HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::AddEintOp,
                                   mlir::zamalang::HLFHE::AddEintOp>>(
      &getContext());
  patterns.insert<
      HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::AddEintIntOp,
                                   mlir::zamalang::HLFHE::AddEintIntOp>>(
      &getContext());
  patterns.insert<
      HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::SubIntEintOp,
                                   mlir::zamalang::HLFHE::SubIntEintOp>>(
      &getContext());
  patterns.insert<
      HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::MulEintIntOp,
                                   mlir::zamalang::HLFHE::MulEintIntOp>>(
      &getContext());
  patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
  patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())
    this->signalPassFailure();
}

} // namespace

namespace mlir {
namespace zamalang {
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg() {
  return std::make_unique<HLFHETensorOpsToLinalg>();
}
} // namespace zamalang
} // namespace mlir
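// Illustrative usage sketch: the pass returned by
// createConvertHLFHETensorOpsToLinalg() runs on functions, so it would
// typically be nested on mlir::FuncOp in a pass manager. `context` and
// `module` are assumed to be an MLIRContext and a ModuleOp set up elsewhere:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
//   if (mlir::failed(pm.run(module)))
//     /* handle the failure */;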