feat(compiler): Add passes for tiling of HLFHELinalg.matmul_eint_int

Add two passes related to the tiling of `HLFHELinalg.matmul_eint_int`
operations.

The `HLFHELinalgTilingMarker` pass takes a vector of tile sizes and
adds an integer array attribute "tile-sizes" to each instance of
`HLFHELinalg.matmul_eint_int`, e.g.,

  "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} :
    (tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>>

The `HLFHELinalgTiling` performs the actual tiling of each
`HLFHELinalg.matmul_eint_int` operation marked with a "tile-sizes"
attribute. The tiling preserves the level of abstraction of
HLFHELinalg and is implemented as a perfect loop nest of SCF for loops
with a `HLFHELinalg.matmul_eint_int` in the body.

For example,

  func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>)
    -> tensor<4x2x!HLFHE.eint<6>>
  {
    %0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} :
           (tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>>
    return %0 : tensor<4x2x!HLFHE.eint<6>>
  }

becomes:

  func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>)
    -> tensor<4x2x!HLFHE.eint<6>>
  {
    %c2 = arith.constant 2 : index
    %c0 = arith.constant 0 : index
    %c4 = arith.constant 4 : index

    %0 = "HLFHELinalg.zero"() : () -> tensor<4x2x!HLFHE.eint<6>>
    %1 = scf.for %arg2 = %c0 to %c4 step %c2 iter_args(%arg3 = %0)
           -> (tensor<4x2x!HLFHE.eint<6>>) {
      %2 = scf.for %arg4 = %c0 to %c2 step %c2 iter_args(%arg5 = %arg3)
             -> (tensor<4x2x!HLFHE.eint<6>>) {
        %3 = scf.for %arg6 = %c0 to %c2 step %c2 iter_args(%arg7 = %arg5)
	       -> (tensor<4x2x!HLFHE.eint<6>>) {
          %4 = tensor.extract_slice %arg0[%arg2, %arg4] [2, 2] [1, 1] :
	         tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
          %5 = tensor.extract_slice %arg1[%arg4, %arg6] [2, 2] [1, 1] :
	         tensor<2x2xi7> to tensor<2x2xi7>
          %6 = tensor.extract_slice %arg7[%arg2, %arg6] [2, 2] [1, 1] :
	         tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>

          %7 = "HLFHELinalg.matmul_eint_int"(%4, %5) :
	         (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2xi7>)
		 -> tensor<2x2x!HLFHE.eint<6>>

          %8 = "HLFHELinalg.add_eint"(%6, %7) :
	         (tensor<2x2x!HLFHE.eint<6>>, tensor<2x2x!HLFHE.eint<6>>)
		 -> tensor<2x2x!HLFHE.eint<6>>

          %9 = tensor.insert_slice %8 into %arg7[%arg2, %arg6] [2, 2] [1, 1] :
	         tensor<2x2x!HLFHE.eint<6>> into tensor<4x2x!HLFHE.eint<6>>

          scf.yield %9 : tensor<4x2x!HLFHE.eint<6>>
        }
        scf.yield %3 : tensor<4x2x!HLFHE.eint<6>>
      }
      scf.yield %2 : tensor<4x2x!HLFHE.eint<6>>
    }
    return %1 : tensor<4x2x!HLFHE.eint<6>>
  }

Only full tiles are supported, i.e., the size of the dimensions of the
operands must be a multiple of the respective tile sizes.
This commit is contained in:
Andi Drebes
2021-12-15 11:21:21 +01:00
parent 77b7aa2f7c
commit bc75831c86
6 changed files with 422 additions and 0 deletions

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Tiling.td)
mlir_tablegen(Tiling.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ZamalangHLFHELinalgTilingPassIncGen)

View File

@@ -0,0 +1,19 @@
#ifndef ZAMALANG_HLFHELINALG_TILING_PASS_H
#define ZAMALANG_HLFHELINALG_TILING_PASS_H
#include <mlir/Pass/Pass.h>
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
#define GEN_PASS_CLASSES
#include <zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h.inc>
namespace mlir {
namespace zamalang {
std::unique_ptr<mlir::OperationPass<>>
createHLFHELinalgTilingMarkerPass(llvm::ArrayRef<int64_t> tileSizes);
std::unique_ptr<mlir::OperationPass<>> createHLFHELinalgTilingPass();
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -0,0 +1,22 @@
#ifndef ZAMALANG_HLFHELINALG_TILING_PASS
#define ZAMALANG_HLFHELINALG_TILING_PASS
include "mlir/Pass/PassBase.td"
def HLFHELinalgTilingMarker : Pass<"hlfhe-linalg-tiling-marker"> {
let summary =
"Marks HLFHELinalg operations for tiling using a vector of tile sizes";
let constructor = "mlir::zamalang::createHLFHELinalgTilingMarkerPass()";
let options = [];
let dependentDialects = [ "mlir::zamalang::HLFHELinalg::HLFHELinalgDialect" ];
}
def HLFHELinalgTiling : Pass<"hlfhe-linalg-tiling"> {
let summary = "Performs tiling of HLFHELinalg operations based on the "
"tile-size attribute";
let constructor = "mlir::zamalang::createHLFHELinalgTilingPass()";
let options = [];
let dependentDialects = [ "mlir::zamalang::HLFHELinalg::HLFHELinalgDialect" ];
}
#endif