feat(compiler): Add option --convert-hlfhe-tensor-ops-to-linalg

This adds a new command line option
`--convert-hlfhe-tensor-ops-to-linalg` that invokes a conversion pass
replacing any HLFHE tensor operation with an appropriate instance of
`linalg.generic`.
This commit is contained in:
Andi Drebes
2021-07-06 11:09:08 +02:00
committed by Quentin Bourgerie
parent 9d1cdc6a0c
commit 4504f090c5
3 changed files with 52 additions and 6 deletions

View File

@@ -5,12 +5,14 @@
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
@@ -24,6 +26,10 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::opt<bool> convertHLFHETensorOpsToLinalg(
"convert-hlfhe-tensor-ops-to-linalg",
llvm::cl::desc("Convert HLFHE tensor operations to linalg operations"));
llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
@@ -44,11 +50,15 @@ llvm::cl::opt<bool> splitInputFile(
// `expected-error` are produced.
//
// If `verifyDiagnostics` is `false`, the procedure checks if the
// parsed module is valid.
// parsed module is valid and if all requested transformations
// succeeded.
mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
llvm::raw_ostream &os, bool verifyDiagnostics,
bool convertHLFHETensorOpsToLinalg) {
mlir::PassManager pm(&context);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
@@ -63,6 +73,16 @@ processInputBuffer(mlir::MLIRContext &context,
if (!module)
return mlir::failure();
if (convertHLFHETensorOpsToLinalg) {
pm.addNestedPass<mlir::FuncOp>(
mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass());
}
if (pm.run(*module).failed()) {
llvm::errs() << "Could not run passes!\n";
return mlir::failure();
}
module->print(os);
return mlir::success();
@@ -112,14 +132,17 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(file),
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics);
return processInputBuffer(
context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics,
cmdline::convertHLFHETensorOpsToLinalg);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(context, std::move(file), output->os(),
cmdline::verifyDiagnostics);
cmdline::verifyDiagnostics,
cmdline::convertHLFHETensorOpsToLinalg);
}
}