mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Add option --convert-hlfhe-tensor-ops-to-linalg
This adds a new command line option `--convert-hlfhe-tensor-ops-to-linalg` that invokes a conversion pass replacing any HLFHE tensor operation with an appropriate instance of `linalg.generic`.
This commit is contained in:
committed by
Quentin Bourgerie
parent
9d1cdc6a0c
commit
4504f090c5
@@ -5,12 +5,14 @@
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Parser.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
|
||||
@@ -24,6 +26,10 @@ llvm::cl::opt<std::string> output("o",
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
llvm::cl::opt<bool> convertHLFHETensorOpsToLinalg(
|
||||
"convert-hlfhe-tensor-ops-to-linalg",
|
||||
llvm::cl::desc("Convert HLFHE tensor operations to linalg operations"));
|
||||
|
||||
llvm::cl::opt<bool> verifyDiagnostics(
|
||||
"verify-diagnostics",
|
||||
llvm::cl::desc("Check that emitted diagnostics match "
|
||||
@@ -44,11 +50,15 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
// `expected-error` are produced.
|
||||
//
|
||||
// If `verifyDiagnostics` is `false`, the procedure checks if the
|
||||
// parsed module is valid.
|
||||
// parsed module is valid and if all requested transformations
|
||||
// succeeded.
|
||||
mlir::LogicalResult
|
||||
processInputBuffer(mlir::MLIRContext &context,
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::raw_ostream &os, bool verifyDiagnostics) {
|
||||
llvm::raw_ostream &os, bool verifyDiagnostics,
|
||||
bool convertHLFHETensorOpsToLinalg) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
@@ -63,6 +73,16 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
if (!module)
|
||||
return mlir::failure();
|
||||
|
||||
if (convertHLFHETensorOpsToLinalg) {
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass());
|
||||
}
|
||||
|
||||
if (pm.run(*module).failed()) {
|
||||
llvm::errs() << "Could not run passes!\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
module->print(os);
|
||||
|
||||
return mlir::success();
|
||||
@@ -112,14 +132,17 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
std::move(file),
|
||||
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
|
||||
llvm::raw_ostream &os) {
|
||||
return processInputBuffer(context, std::move(inputBuffer), os,
|
||||
cmdline::verifyDiagnostics);
|
||||
return processInputBuffer(
|
||||
context, std::move(inputBuffer), os,
|
||||
cmdline::verifyDiagnostics,
|
||||
cmdline::convertHLFHETensorOpsToLinalg);
|
||||
},
|
||||
output->os())))
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return processInputBuffer(context, std::move(file), output->os(),
|
||||
cmdline::verifyDiagnostics);
|
||||
cmdline::verifyDiagnostics,
|
||||
cmdline::convertHLFHETensorOpsToLinalg);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user