feat(compiler): Add option --hlfhelinalg-tile-sizes for global HLFHELinalg tile sizes

Add a new option --hlfhelinalg-tile-sizes that forces tiling of
HLFHELinalg operations and that sets the tile sizes to the sizes given
in the parameter. The specification of the tile sizes is a
comma-separated list of integers, e.g.,

  --hlfhelinalg-tile-sizes=2,2,2

forces to use tiles of size 2 in each dimension.
This commit is contained in:
Andi Drebes
2021-12-15 16:57:03 +01:00
parent f319ba37d2
commit 7010a509d2

View File

@@ -142,6 +142,11 @@ llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
llvm::cl::desc(
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
llvm::cl::list<int64_t> hlfhelinalgTileSizes(
"hlfhelinalg-tile-sizes",
llvm::cl::desc(
"Force tiling of HLFHELinalg operation with the given tile sizes"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
} // namespace cmdline
llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
@@ -222,6 +227,7 @@ mlir::LogicalResult processInputBuffer(
llvm::ArrayRef<uint64_t> jitArgs,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
llvm::Optional<llvm::ArrayRef<int64_t>> hlfhelinalgTileSizes,
llvm::raw_ostream &os,
std::shared_ptr<mlir::zamalang::CompilerEngine::Library> outputLib) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
@@ -244,6 +250,9 @@ mlir::LogicalResult processInputBuffer(
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());
if (hlfhelinalgTileSizes.hasValue())
ce.setHLFHELinalgTileSizes(*hlfhelinalgTileSizes);
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName);
@@ -354,6 +363,12 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
}
}
// Convert tile sizes to `Optional`
llvm::Optional<llvm::ArrayRef<int64_t>> hlfhelinalgTileSizes;
if (!cmdline::hlfhelinalgTileSizes.empty())
hlfhelinalgTileSizes.emplace(cmdline::hlfhelinalgTileSizes);
// In case of compilation to library, the real output is the library.
std::string outputPath =
(cmdline::action == Action::COMPILE) ? cmdline::STDOUT : cmdline::output;
@@ -388,7 +403,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(inputBuffer), fileName, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, os, outputLib);
cmdline::verifyDiagnostics, hlfhelinalgTileSizes, os, outputLib);
};
auto &os = output->os();
auto res = mlir::failure();