mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Add passes for tiling of HLFHELinalg.matmul_eint_int
Add two passes related to the tiling of `HLFHELinalg.matmul_eint_int`
operations.
The `HLFHELinalgTilingMarker` pass takes a vector of tile sizes and
adds an integer array attribute "tile-sizes" to each instance of
`HLFHELinalg.matmul_eint_int`, e.g.,
"HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} :
(tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>>
The `HLFHELinalgTiling` performs the actual tiling of each
`HLFHELinalg.matmul_eint_int` operation marked with a "tile-sizes"
attribute. The tiling preserves the level of abstraction of
HLFHELinalg and is implemented as a perfect loop nest of SCF for loops
with a `HLFHELinalg.matmul_eint_int` in the body.
For example,
func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>)
-> tensor<4x2x!HLFHE.eint<6>>
{
%0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) {"tile-sizes" = [2, 2, 2]} :
(tensor<4x2x!HLFHE.eint<6>>, tensor<2x2xi7>) -> tensor<4x2x!HLFHE.eint<6>>
return %0 : tensor<4x2x!HLFHE.eint<6>>
}
becomes:
func @main(%arg0: tensor<4x2x!HLFHE.eint<6>>, %arg1: tensor<2x2xi7>)
-> tensor<4x2x!HLFHE.eint<6>>
{
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%0 = "HLFHELinalg.zero"() : () -> tensor<4x2x!HLFHE.eint<6>>
%1 = scf.for %arg2 = %c0 to %c4 step %c2 iter_args(%arg3 = %0)
-> (tensor<4x2x!HLFHE.eint<6>>) {
%2 = scf.for %arg4 = %c0 to %c2 step %c2 iter_args(%arg5 = %arg3)
-> (tensor<4x2x!HLFHE.eint<6>>) {
%3 = scf.for %arg6 = %c0 to %c2 step %c2 iter_args(%arg7 = %arg5)
-> (tensor<4x2x!HLFHE.eint<6>>) {
%4 = tensor.extract_slice %arg0[%arg2, %arg4] [2, 2] [1, 1] :
tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
%5 = tensor.extract_slice %arg1[%arg4, %arg6] [2, 2] [1, 1] :
tensor<2x2xi7> to tensor<2x2xi7>
%6 = tensor.extract_slice %arg7[%arg2, %arg6] [2, 2] [1, 1] :
tensor<4x2x!HLFHE.eint<6>> to tensor<2x2x!HLFHE.eint<6>>
%7 = "HLFHELinalg.matmul_eint_int"(%4, %5) :
(tensor<2x2x!HLFHE.eint<6>>, tensor<2x2xi7>)
-> tensor<2x2x!HLFHE.eint<6>>
%8 = "HLFHELinalg.add_eint"(%6, %7) :
(tensor<2x2x!HLFHE.eint<6>>, tensor<2x2x!HLFHE.eint<6>>)
-> tensor<2x2x!HLFHE.eint<6>>
%9 = tensor.insert_slice %8 into %arg7[%arg2, %arg6] [2, 2] [1, 1] :
tensor<2x2x!HLFHE.eint<6>> into tensor<4x2x!HLFHE.eint<6>>
scf.yield %9 : tensor<4x2x!HLFHE.eint<6>>
}
scf.yield %3 : tensor<4x2x!HLFHE.eint<6>>
}
scf.yield %2 : tensor<4x2x!HLFHE.eint<6>>
}
return %1 : tensor<4x2x!HLFHE.eint<6>>
}
Only full tiles are supported, i.e., the size of the dimensions of the
operands must be a multiple of the respective tile sizes.
This commit is contained in:
@@ -1 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Tiling.td)
|
||||
mlir_tablegen(Tiling.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ZamalangHLFHELinalgTilingPassIncGen)
|
||||
@@ -0,0 +1,19 @@
|
||||
#ifndef ZAMALANG_HLFHELINALG_TILING_PASS_H
|
||||
#define ZAMALANG_HLFHELINALG_TILING_PASS_H
|
||||
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<mlir::OperationPass<>>
|
||||
createHLFHELinalgTilingMarkerPass(llvm::ArrayRef<int64_t> tileSizes);
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createHLFHELinalgTilingPass();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,22 @@
|
||||
#ifndef ZAMALANG_HLFHELINALG_TILING_PASS
|
||||
#define ZAMALANG_HLFHELINALG_TILING_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def HLFHELinalgTilingMarker : Pass<"hlfhe-linalg-tiling-marker"> {
|
||||
let summary =
|
||||
"Marks HLFHELinalg operations for tiling using a vector of tile sizes";
|
||||
let constructor = "mlir::zamalang::createHLFHELinalgTilingMarkerPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::zamalang::HLFHELinalg::HLFHELinalgDialect" ];
|
||||
}
|
||||
|
||||
def HLFHELinalgTiling : Pass<"hlfhe-linalg-tiling"> {
|
||||
let summary = "Performs tiling of HLFHELinalg operations based on the "
|
||||
"tile-size attribute";
|
||||
let constructor = "mlir::zamalang::createHLFHELinalgTilingPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::zamalang::HLFHELinalg::HLFHELinalgDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user