feat(compiler): Add option --convert-hlfhe-tensor-ops-to-linalg

This adds a new command line option
`--convert-hlfhe-tensor-ops-to-linalg` that invokes a conversion pass
replacing any HLFHE tensor operation with an appropriate instance of
`linalg.generic`.
This commit is contained in:
Andi Drebes
2021-07-06 11:09:08 +02:00
committed by Quentin Bourgerie
parent 9d1cdc6a0c
commit 4504f090c5
3 changed files with 52 additions and 6 deletions

View File

@@ -5,6 +5,7 @@ target_link_libraries(zamacompiler
PRIVATE
MLIRTransforms
MidLFHEDialect
HLFHEDialect)
HLFHEDialect
HLFHEDialectTransforms)
mlir_check_all_link_libraries(zamacompiler)

View File

@@ -5,12 +5,14 @@
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
@@ -24,6 +26,10 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::opt<bool> convertHLFHETensorOpsToLinalg(
"convert-hlfhe-tensor-ops-to-linalg",
llvm::cl::desc("Convert HLFHE tensor operations to linalg operations"));
llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
@@ -44,11 +50,15 @@ llvm::cl::opt<bool> splitInputFile(
// `expected-error` are produced.
//
// If `verifyDiagnostics` is `false`, the procedure checks if the
// parsed module is valid.
// parsed module is valid and if all requested transformations
// succeeded.
mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
llvm::raw_ostream &os, bool verifyDiagnostics,
bool convertHLFHETensorOpsToLinalg) {
mlir::PassManager pm(&context);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
@@ -63,6 +73,16 @@ processInputBuffer(mlir::MLIRContext &context,
if (!module)
return mlir::failure();
if (convertHLFHETensorOpsToLinalg) {
pm.addNestedPass<mlir::FuncOp>(
mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass());
}
if (pm.run(*module).failed()) {
llvm::errs() << "Could not run passes!\n";
return mlir::failure();
}
module->print(os);
return mlir::success();
@@ -112,14 +132,17 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(file),
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics);
return processInputBuffer(
context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics,
cmdline::convertHLFHETensorOpsToLinalg);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(context, std::move(file), output->os(),
cmdline::verifyDiagnostics);
cmdline::verifyDiagnostics,
cmdline::convertHLFHETensorOpsToLinalg);
}
}

View File

@@ -0,0 +1,22 @@
// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>
// CHECK-NEXT: module {
// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<0>>, %[[A1:.*]]: memref<2xi32>, %[[A2:.*]]: memref<!HLFHE.eint<0>>)
// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<0>>, memref<2xi32>) outs(%arg2 : memref<!HLFHE.eint<0>>) {
// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<0>, %[[A4:.*]]: i32, %[[A5:.*]]: !HLFHE.eint<0>): // no predecessors
// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<0>
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>,
%arg1: memref<2xi32>,
%arg2: memref<!HLFHE.eint<0>>)
{
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref<!HLFHE.eint<0>>) -> ()
return
}