diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 13ccca867..ceeae6880 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -142,6 +142,11 @@ llvm::cl::opt, false, OptionalSizeTParser> assumeMaxMANP( llvm::cl::desc( "Assume a maximum for the Minimum Arithmetic Noise Padding")); +llvm::cl::list hlfhelinalgTileSizes( + "hlfhelinalg-tile-sizes", + llvm::cl::desc( + "Force tiling of HLFHELinalg operation with the given tile sizes"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); } // namespace cmdline llvm::Expected buildFHEContext( @@ -222,6 +227,7 @@ mlir::LogicalResult processInputBuffer( llvm::ArrayRef jitArgs, llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, + llvm::Optional> hlfhelinalgTileSizes, llvm::raw_ostream &os, std::shared_ptr outputLib) { std::shared_ptr ccx = @@ -244,6 +250,9 @@ mlir::LogicalResult processInputBuffer( if (overrideMaxMANP.hasValue()) ce.setMaxMANP(overrideMaxMANP.getValue()); + if (hlfhelinalgTileSizes.hasValue()) + ce.setHLFHELinalgTileSizes(*hlfhelinalgTileSizes); + if (action == Action::JIT_INVOKE) { llvm::Expected lambdaOrErr = ce.buildLambda(std::move(buffer), jitFuncName); @@ -354,6 +363,12 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { } } + // Convert tile sizes to `Optional` + llvm::Optional> hlfhelinalgTileSizes; + + if (!cmdline::hlfhelinalgTileSizes.empty()) + hlfhelinalgTileSizes.emplace(cmdline::hlfhelinalgTileSizes); + // In case of compilation to library, the real output is the library. std::string outputPath = (cmdline::action == Action::COMPILE) ? cmdline::STDOUT : cmdline::output; @@ -388,7 +403,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { std::move(inputBuffer), fileName, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, os, outputLib); + cmdline::verifyDiagnostics, hlfhelinalgTileSizes, os, outputLib); }; auto &os = output->os(); auto res = mlir::failure();