From bc75831c867d02b0c2974411edede61e14f6949d Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Wed, 15 Dec 2021 11:21:21 +0100 Subject: [PATCH] feat(compiler): Add passes for tiling of HLFHELinalg.matmul_eint_int Add two passes related to the tiling of `HLFHELinalg.matmul_eint_int` operations. The `HLFHELinalgTilingMarker` pass takes a vector of tile sizes and adds an integer array attribute "tile-sizes" to each instance of `HLFHELinalg.matmul_eint_int`, e.g., "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} : (tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>> The `HLFHELinalgTiling` performs the actual tiling of each `HLFHELinalg.matmul_eint_int` operation marked with a "tile-sizes" attribute. The tiling preserves the level of abstraction of HLFHELinalg and is implemented as a perfect loop nest of SCF for loops with a `HLFHELinalg.matmul_eint_int` in the body. For example, func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>> { %0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} : (tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>> return %0 : tensor<4x2x!HLFHE.eint<6>> } becomes: func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>> { %c2 = arith.constant 2 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %0 = "HLFHELinalg.zero"() : () -> tensor<4x2x!HLFHE.eint<6>> %1 = scf.for %arg2 = %c0 to %c4 step %c2 iter_args(%arg3 = %0) -> (tensor<4x2x!HLFHE.eint<6>>) { %2 = scf.for %arg4 = %c0 to %c2 step %c2 iter_args(%arg5 = %arg3) -> (tensor<4x2x!HLFHE.eint<6>>) { %3 = scf.for %arg6 = %c0 to %c2 step %c2 iter_args(%arg7 = %arg5) -> (tensor<4x2x!HLFHE.eint<6>>) { %4 = tensor.extract_slice %arg0[%arg2, %arg4] [2, 2] [1, 1] : tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>> %5 = tensor.extract_slice %arg1[%arg4, %arg6] [2, 2] [1, 1] : tensor<2x2xi7> to 
tensor<2x2xi7> %6 = tensor.extract_slice %arg7[%arg2, %arg6] [2, 2] [1, 1] : tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>> %7 = "HLFHELinalg.matmul_eint_int"(%4, %5) : (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<2x2x!HLFHE.eint<6>> %8 = "HLFHELinalg.add_eint"(%6, %7) : (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2x!HLFHE.eint<6>>) -> tensor<2x2x!HLFHE.eint<6>> %9 = tensor.insert_slice %8 into %arg7[%arg2, %arg6] [2, 2] [1, 1] : tensor<2x2x!HLFHE.eint<6>> into tensor<4x2x!HLFHE.eint<6>> scf.yield %9 : tensor<4x2x!HLFHE.eint<6>> } scf.yield %3 : tensor<4x2x!HLFHE.eint<6>> } scf.yield %2 : tensor<4x2x!HLFHE.eint<6>> } return %1 : tensor<4x2x!HLFHE.eint<6>> } Only full tiles are supported, i.e., the size of the dimensions of the operands must be a multiple of the respective tile sizes. --- .../Dialect/HLFHELinalg/CMakeLists.txt | 1 + .../HLFHELinalg/Transforms/CMakeLists.txt | 3 + .../Dialect/HLFHELinalg/Transforms/Tiling.h | 19 + .../Dialect/HLFHELinalg/Transforms/Tiling.td | 22 ++ .../HLFHELinalg/Transforms/CMakeLists.txt | 15 + .../Dialect/HLFHELinalg/Transforms/Tiling.cpp | 362 ++++++++++++++++++ 6 files changed, 422 insertions(+) create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/CMakeLists.txt create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h create mode 100644 compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.td create mode 100644 compiler/lib/Dialect/HLFHELinalg/Transforms/CMakeLists.txt create mode 100644 compiler/lib/Dialect/HLFHELinalg/Transforms/Tiling.cpp diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt b/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git 
a/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/CMakeLists.txt b/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/CMakeLists.txt new file mode 100644 index 000000000..afccddaa5 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Tiling.td) +mlir_tablegen(Tiling.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ZamalangHLFHELinalgTilingPassIncGen) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h b/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h new file mode 100644 index 000000000..a109da047 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h @@ -0,0 +1,19 @@ +#ifndef ZAMALANG_HLFHELINALG_TILING_PASS_H +#define ZAMALANG_HLFHELINALG_TILING_PASS_H + +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace zamalang { +std::unique_ptr> +createHLFHELinalgTilingMarkerPass(llvm::ArrayRef tileSizes); + +std::unique_ptr> createHLFHELinalgTilingPass(); +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.td b/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.td new file mode 100644 index 000000000..cd1210cd4 --- /dev/null +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/Transforms/Tiling.td @@ -0,0 +1,22 @@ +#ifndef ZAMALANG_HLFHELINALG_TILING_PASS +#define ZAMALANG_HLFHELINALG_TILING_PASS + +include "mlir/Pass/PassBase.td" + +def HLFHELinalgTilingMarker : Pass<"hlfhe-linalg-tiling-marker"> { + let summary = + "Marks HLFHELinalg operations for tiling using a vector of tile sizes"; + let constructor = "mlir::zamalang::createHLFHELinalgTilingMarkerPass()"; + let options = []; + let dependentDialects = [ "mlir::zamalang::HLFHELinalg::HLFHELinalgDialect" ]; +} + +def HLFHELinalgTiling : Pass<"hlfhe-linalg-tiling"> { + let summary = "Performs tiling of 
HLFHELinalg operations based on the " + "tile-size attribute"; + let constructor = "mlir::zamalang::createHLFHELinalgTilingPass()"; + let options = []; + let dependentDialects = [ "mlir::zamalang::HLFHELinalg::HLFHELinalgDialect" ]; +} + +#endif diff --git a/compiler/lib/Dialect/HLFHELinalg/Transforms/CMakeLists.txt b/compiler/lib/Dialect/HLFHELinalg/Transforms/CMakeLists.txt new file mode 100644 index 000000000..723fb698c --- /dev/null +++ b/compiler/lib/Dialect/HLFHELinalg/Transforms/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_library(HLFHELinalgDialectTransforms + Tiling.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHELinalg + + DEPENDS + HLFHELinalgDialect + ZamalangHLFHELinalgTilingPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + HLFHELinalgDialect) + +target_link_libraries(HLFHELinalgDialectTransforms PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/HLFHELinalg/Transforms/Tiling.cpp b/compiler/lib/Dialect/HLFHELinalg/Transforms/Tiling.cpp new file mode 100644 index 000000000..c92bdbd19 --- /dev/null +++ b/compiler/lib/Dialect/HLFHELinalg/Transforms/Tiling.cpp @@ -0,0 +1,362 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +namespace { + +// Creates a `tensor.extract_slice` operation that extracts a +// contiguous, 2-dimensional slice with a static size specified by +// `sizes` at the dynamic offset `offsets`. 
+mlir::tensor::ExtractSliceOp +extractContiguous2DSlice(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Value T, llvm::ArrayRef sizes, + llvm::ArrayRef offsets) { + assert(sizes.size() == 2 && offsets.size() == 2 && + "The number of dimensions for the size and offset must be 2"); + + mlir::Type elTy = T.getType().cast().getElementType(); + + return builder.create( + loc, mlir::RankedTensorType::get(sizes, elTy), T, offsets, + llvm::SmallVector{ + builder.getI64IntegerAttr(sizes[0]), + builder.getI64IntegerAttr(sizes[1])}, + llvm::SmallVector{builder.getI64IntegerAttr(1), + builder.getI64IntegerAttr(1)}); +} + +// Creates a perfect loop nest of SCF for loops with the lower bounds +// `lbs`, the upper bounds `ubs` and the steps `steps` in the order +// from the outermost to the innermost loop. The values specified in +// `loopCarriedDeps` are loop-carried dependencies carried across all +// loops. +// +// The function `func` is called with a builder for the body of the +// innermost loop, the original location `loc`, a vector with all +// induction variables from the outermost to the innermost loop and the +// loop-carried dependencies. +// +// Returns the outermost loop. 
+mlir::scf::ForOp buildLoopNestWithLoopCarriedDependency( + mlir::OpBuilder builder, mlir::Location loc, + llvm::ArrayRef lbs, llvm::ArrayRef ubs, + llvm::ArrayRef steps, + llvm::ArrayRef loopCarriedDeps, + function_ref func = + nullptr) { + + size_t nLoops = lbs.size(); + + assert(nLoops > 0 && ubs.size() == nLoops && steps.size() == nLoops && + "Attempting to build loop nest with incomplete specification"); + + llvm::SmallVector loopCarriedDepsUpd(loopCarriedDeps.begin(), + loopCarriedDeps.end()); + llvm::SmallVector inductionVars; + llvm::SmallVector fops; + + // Create the loops and construct body of the innermost loop using the + // callback function + for (size_t i = 0; i < nLoops; i++) { + mlir::scf::ForOp fop = builder.create( + loc, lbs[i], ubs[i], steps[i], loopCarriedDepsUpd, + + [&](mlir::OpBuilder &builder, mlir::Location location, + mlir::Value indVar, mlir::ValueRange iterArgs) -> void { + loopCarriedDepsUpd = iterArgs; + inductionVars.push_back(indVar); + + mlir::OpBuilder opb(builder); + + if (i == nLoops - 1 && func) + func(opb, location, inductionVars, iterArgs); + }); + + builder.setInsertionPoint(fop.getBody(), fop.getBody()->end()); + + fops.push_back(fop); + } + + // Return updated loop-carried dependencies via scf.yield operations + for (size_t i = 0; i < nLoops - 1; i++) { + builder.setInsertionPoint(fops[i].getBody(), fops[i].getBody()->end()); + builder.create(loc, fops[i + 1].getResults()); + } + + return fops[0]; +} + +// Marker to avoid infinite recursion of the rewriting pattern +static const mlir::StringLiteral kTransformMarker = + "__internal_hlfhe_linalg_tiling_marker__"; + +// Rewrite an `HLFHELinalg.matmul_eint_int` operation as an equivalent +// sequence of operations consisting of a perfect loop nest of SCF for +// loops with a `HLFHELinalg.matmul_eint_int` operation that performs +// a matrix multiplication on a single tile. 
+// +// The terminology is as follows: +// +// - A: The input matrix of encrypted integers of size `NxM` +// - B: The input matrix of plaintext integers of size `MxK` +// - C: The output matrix of encrypted integers of size `NxK` +// +// At each iteration of the innermost loop, the generated +// `HLFHELinalg.matmul_eint_int` operation performs a multiplication +// of a matrix tile of size `TxU` and a matrix of size `UxV`, +// producing a tile of size `TxV`. +// +// Partial tiles are currently not supported, i.e., `N` must be a +// multiple of `T`, `M` a multiple of `U` and `K` a multiple of `V`. +class MatMulTilingPattern : public mlir::OpRewritePattern< + mlir::zamalang::HLFHELinalg::MatMulEintIntOp> { +public: + MatMulTilingPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern( + context, ::mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {} + + mlir::LogicalResult + matchAndRewrite(mlir::zamalang::HLFHELinalg::MatMulEintIntOp op, + mlir::PatternRewriter &rewriter) const override { + // Avoid infinite recursion by marking each matmul operation and + // bailing out for the marker + if (op->hasAttr(kTransformMarker)) + return mlir::failure(); + + // Only tile operations that are explicitly marked for tiling with + // tile sizes + if (!op->hasAttr("tile-sizes")) + return mlir::failure(); + + // Original location of the operation to be replaced with the + // tiling + mlir::Location origLoc = op->getLoc(); + + mlir::ArrayAttr tileSizes = + op->getAttrOfType("tile-sizes"); + + if (!tileSizes) { + op->emitError("Wrong type for attribute \"tile-size\""); + return mlir::failure(); + } + + if (tileSizes.size() != 3) { + op->emitError("Need 3 tile sizes, but got ") << tileSizes.size(); + return mlir::failure(); + } + + // Extract tile sizes + mlir::IntegerAttr attrT = + tileSizes[0].dyn_cast_or_null(); + mlir::IntegerAttr attrU = + tileSizes[1].dyn_cast_or_null(); + mlir::IntegerAttr attrV = + tileSizes[2].dyn_cast_or_null(); + + if (!attrT || !attrU || !attrV) { + 
op->emitError("Wrong type for tile sizes"); + return mlir::failure(); + } + + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.startRootUpdate(op); + rewriter.setInsertionPointAfter(op); + + // Plain integer tile sizes + int64_t iT = attrT.getInt(); + int64_t iU = attrU.getInt(); + int64_t iV = attrV.getInt(); + + mlir::Value A = op.getOperand(0); + mlir::Value B = op.getOperand(1); + + // Initialization of the output matrix with zeros + mlir::zamalang::HLFHELinalg::ZeroOp Cinit = + rewriter.create( + origLoc, op.getResult().getType()); + + mlir::TensorType ATTy = A.getType().cast(); + mlir::TensorType BTTy = B.getType().cast(); + mlir::TensorType CTTy = + Cinit.getResult().getType().cast(); + + if (!ATTy.hasStaticShape() || !BTTy.hasStaticShape() || + !CTTy.hasStaticShape()) { + op.emitError() << "Can only tile matrix multiplications on statically " + "shaped tensors"; + return mlir::failure(); + } + + // Check that no partial tiles are necessary + if (ATTy.getDimSize(0) % iT != 0 || ATTy.getDimSize(1) % iU != 0 || + BTTy.getDimSize(1) % iV != 0) { + op.emitError() << "Dimensions of the tensors must be a multiple of the " + "tile size. 
Partial tiles are currently not supported."; + return mlir::failure(); + } + + mlir::arith::ConstantIndexOp T = + rewriter.create(origLoc, iT); + + mlir::arith::ConstantIndexOp U = + rewriter.create(origLoc, iU); + + mlir::arith::ConstantIndexOp V = + rewriter.create(origLoc, iV); + + // Lower bound for all for loops + mlir::arith::ConstantIndexOp lb = + rewriter.create(origLoc, 0); + + // Upper bounds are determined by the size of the operands + mlir::arith::ConstantIndexOp ubT = + rewriter.create(origLoc, + ATTy.getDimSize(0)); + + mlir::arith::ConstantIndexOp ubU = + rewriter.create(origLoc, + ATTy.getDimSize(1)); + + mlir::arith::ConstantIndexOp ubV = + rewriter.create(origLoc, + BTTy.getDimSize(1)); + + // Bounds and steps in vector form + llvm::SmallVector lbs{lb, lb, lb}; + llvm::SmallVector ubs{ubT, ubU, ubV}; + llvm::SmallVector steps{T, U, V}; + + // Callback function to build the body of the innermost loop + auto innermostBodyBuilder = [&](mlir::OpBuilder &builder, + mlir::Location location, + mlir::ValueRange inductionVars, + mlir::ValueRange iterArgs) { + // TxU tile from A + mlir::tensor::ExtractSliceOp ATile = extractContiguous2DSlice( + builder, origLoc, A, {iT, iU}, {inductionVars[0], inductionVars[1]}); + // UxV tile from B + mlir::tensor::ExtractSliceOp BTile = extractContiguous2DSlice( + builder, origLoc, B, {iU, iV}, {inductionVars[1], inductionVars[2]}); + + // TxV tile from C + mlir::tensor::ExtractSliceOp CTile = extractContiguous2DSlice( + builder, origLoc, *iterArgs.begin(), {iT, iV}, + {inductionVars[0], inductionVars[2]}); + + // Multiplication of the tiles + mlir::zamalang::HLFHELinalg::MatMulEintIntOp tiledMul = + builder.create( + origLoc, + mlir::RankedTensorType::get(llvm::SmallVector{iT, iV}, + CTTy.getElementType()), + ATile, BTile); + + // Mark matrix multiplication to prevent recursive + // application of the rewriting pattern + tiledMul.getOperation()->setAttr(kTransformMarker, + rewriter.getUnitAttr()); + + // Add result of 
the multiplication of the tiles to the + // result tile from C + mlir::zamalang::HLFHELinalg::AddEintOp accuTile = + builder.create(origLoc, CTile, + tiledMul); + + // Write updated C tile back into C + mlir::tensor::InsertSliceOp Cupdated = + builder.create( + origLoc, accuTile, *iterArgs.begin(), + + llvm::SmallVector{inductionVars[0], + inductionVars[2]}, + + llvm::SmallVector{ + rewriter.getI64IntegerAttr(iT), + rewriter.getI64IntegerAttr(iV)}, + + llvm::SmallVector{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1)}); + + builder.create(origLoc, Cupdated.getResult()); + }; + + mlir::scf::ForOp outermost = buildLoopNestWithLoopCarriedDependency( + rewriter, origLoc, lbs, ubs, steps, Cinit.getResult(), + innermostBodyBuilder); + + rewriter.replaceOp(op, outermost.getResult(0)); + + rewriter.finalizeRootUpdate(op); + + return mlir::success(); + } +}; + +// Performs the actual tiling of `HLFHELinalg.matmul_eint_int` +// operations that have been marked with a "tile-sizes" attribute. +class HLFHELinalgTilingPass + : public HLFHELinalgTilingBase { +public: + void runOnOperation() override { + mlir::Operation *op = getOperation(); + + mlir::RewritePatternSet patterns(op->getContext()); + patterns.add(op->getContext()); + + if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) { + this->signalPassFailure(); + } + + op->walk([](mlir::zamalang::HLFHELinalg::MatMulEintIntOp matmulOp) { + matmulOp.getOperation()->removeAttr(kTransformMarker); + }); + } +}; + +// Marks all `HLFHELinalg.matmul_eint_int` operations with a +// "tile-sizes" attribute containing the specified tile sizes. 
+class HLFHELinalgTilingMarkerPass + : public HLFHELinalgTilingMarkerBase { +public: + HLFHELinalgTilingMarkerPass(llvm::ArrayRef tileSizes) + : tileSizes(tileSizes.vec()) {} + + void runOnOperation() override { + mlir::Operation *op = getOperation(); + + mlir::ArrayAttr tileAttr = + mlir::Builder(&this->getContext()).getI64ArrayAttr(tileSizes); + + op->walk([&](mlir::zamalang::HLFHELinalg::MatMulEintIntOp matmulOp) { + matmulOp.getOperation()->setAttr("tile-sizes", tileAttr); + }); + } + +protected: + std::vector tileSizes; +}; +} // end anonymous namespace + +std::unique_ptr> createHLFHELinalgTilingPass() { + return std::make_unique(); +} + +std::unique_ptr> +createHLFHELinalgTilingMarkerPass(llvm::ArrayRef tileSizes) { + return std::make_unique(tileSizes); +} +} // namespace zamalang +} // namespace mlir