mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): Add option --convert-hlfhe-tensor-ops-to-linalg
This adds a new command line option `--convert-hlfhe-tensor-ops-to-linalg` that invokes a conversion pass replacing any HLFHE tensor operation with an appropriate instance of `linalg.generic`.
This commit is contained in:
committed by
Quentin Bourgerie
parent
9d1cdc6a0c
commit
4504f090c5
@@ -5,6 +5,7 @@ target_link_libraries(zamacompiler
|
||||
PRIVATE
|
||||
MLIRTransforms
|
||||
MidLFHEDialect
|
||||
HLFHEDialect)
|
||||
HLFHEDialect
|
||||
HLFHEDialectTransforms)
|
||||
|
||||
mlir_check_all_link_libraries(zamacompiler)
|
||||
|
||||
@@ -5,12 +5,14 @@
|
||||
#include <llvm/Support/ToolOutputFile.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Parser.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/FileUtilities.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Support/ToolUtilities.h>
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
|
||||
@@ -24,6 +26,10 @@ llvm::cl::opt<std::string> output("o",
|
||||
llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
llvm::cl::opt<bool> convertHLFHETensorOpsToLinalg(
|
||||
"convert-hlfhe-tensor-ops-to-linalg",
|
||||
llvm::cl::desc("Convert HLFHE tensor operations to linalg operations"));
|
||||
|
||||
llvm::cl::opt<bool> verifyDiagnostics(
|
||||
"verify-diagnostics",
|
||||
llvm::cl::desc("Check that emitted diagnostics match "
|
||||
@@ -44,11 +50,15 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
// `expected-error` are produced.
|
||||
//
|
||||
// If `verifyDiagnostics` is `false`, the procedure checks if the
|
||||
// parsed module is valid.
|
||||
// parsed module is valid and if all requested transformations
|
||||
// succeeded.
|
||||
mlir::LogicalResult
|
||||
processInputBuffer(mlir::MLIRContext &context,
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::raw_ostream &os, bool verifyDiagnostics) {
|
||||
llvm::raw_ostream &os, bool verifyDiagnostics,
|
||||
bool convertHLFHETensorOpsToLinalg) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
@@ -63,6 +73,16 @@ processInputBuffer(mlir::MLIRContext &context,
|
||||
if (!module)
|
||||
return mlir::failure();
|
||||
|
||||
if (convertHLFHETensorOpsToLinalg) {
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass());
|
||||
}
|
||||
|
||||
if (pm.run(*module).failed()) {
|
||||
llvm::errs() << "Could not run passes!\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
module->print(os);
|
||||
|
||||
return mlir::success();
|
||||
@@ -112,14 +132,17 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
std::move(file),
|
||||
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
|
||||
llvm::raw_ostream &os) {
|
||||
return processInputBuffer(context, std::move(inputBuffer), os,
|
||||
cmdline::verifyDiagnostics);
|
||||
return processInputBuffer(
|
||||
context, std::move(inputBuffer), os,
|
||||
cmdline::verifyDiagnostics,
|
||||
cmdline::convertHLFHETensorOpsToLinalg);
|
||||
},
|
||||
output->os())))
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return processInputBuffer(context, std::move(file), output->os(),
|
||||
cmdline::verifyDiagnostics);
|
||||
cmdline::verifyDiagnostics,
|
||||
cmdline::convertHLFHETensorOpsToLinalg);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
22
compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir
Normal file
22
compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir
Normal file
@@ -0,0 +1,22 @@
|
||||
// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<0>>, %[[A1:.*]]: memref<2xi32>, %[[A2:.*]]: memref<!HLFHE.eint<0>>)
|
||||
// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<0>>, memref<2xi32>) outs(%arg2 : memref<!HLFHE.eint<0>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<0>, %[[A4:.*]]: i32, %[[A5:.*]]: !HLFHE.eint<0>): // no predecessors
|
||||
// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<0>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>,
|
||||
%arg1: memref<2xi32>,
|
||||
%arg2: memref<!HLFHE.eint<0>>)
|
||||
{
|
||||
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
|
||||
(memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user