feat/refactor(compiler): Add --passes options to activate only a subset of passes (#57)

This commit is contained in:
Quentin Bourgerie
2021-07-22 19:18:30 +02:00
parent 1605551f1a
commit 4e6579e019
25 changed files with 100 additions and 101 deletions

View File

@@ -2,6 +2,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
add_public_tablegen_target(MLIRConversionPassIncGen)
add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc)
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(HLFHEToMidLFHE)

View File

@@ -0,0 +1,15 @@
#ifndef ZAMALANG_CONVERSION_HLFHETENSOROPSTOLINALG_PASS_H_
#define ZAMALANG_CONVERSION_HLFHETENSOROPSTOLINALG_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace zamalang {
/// Create a pass to convert `HLFHE` tensor operators to linal.generic
/// operators.
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg();
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1,26 +1,14 @@
//===- Passes.h - Conversion Pass Construction and Registration -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ZAMALANG_CONVERSION_PASSES_H
#define ZAMALANG_CONVERSION_PASSES_H
#ifndef ZAMALANG_TRANSFORMS_PASSES_H
#define ZAMALANG_TRANSFORMS_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
// namespace mlir {
// namespace zamalang {
#define GEN_PASS_CLASSES
#include "zamalang/Conversion/Passes.h.inc"
// } // namespace zamalang
// } // namespace mlir
#endif

View File

@@ -1 +1 @@
add_subdirectory(IR)
add_subdirectory(IR)

View File

@@ -1,14 +0,0 @@
#ifndef ZAMALANG_DIALECT_HLFHE_TRANSFORMS_TENSOROPSTOLINALG_H
#define ZAMALANG_DIALECT_HLFHE_TRANSFORMS_TENSOROPSTOLINALG_H
#include <mlir/Pass/Pass.h>
namespace mlir {
namespace zamalang {
namespace HLFHE {
std::unique_ptr<mlir::Pass> createLowerTensorOpsToLinalgPass();
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1 +1,2 @@
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(HLFHETensorOpsToLinalg)

View File

@@ -1,4 +1,4 @@
add_mlir_dialect_library(HLFHEDialectTransforms
add_mlir_dialect_library(HLFHETensorOpsToLinalg
TensorOpsToLinalg.cpp
ADDITIONAL_HEADER_DIRS
@@ -6,6 +6,7 @@ add_mlir_dialect_library(HLFHEDialectTransforms
DEPENDS
HLFHEDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR

View File

@@ -1,14 +1,16 @@
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/OperationSupport.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
#include <iostream>
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
struct DotToLinalgGeneric : public ::mlir::RewritePattern {
DotToLinalgGeneric(::mlir::MLIRContext *context)
@@ -88,15 +90,13 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
};
namespace {
struct LowerTensorOpsToLinalgPass
: public mlir::PassWrapper<LowerTensorOpsToLinalgPass, mlir::FunctionPass> {
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::linalg::LinalgDialect>();
}
struct HLFHETensorOpsToLinalg
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
void runOnFunction() final;
};
void LowerTensorOpsToLinalgPass::runOnFunction() {
void HLFHETensorOpsToLinalg::runOnFunction() {
mlir::FuncOp function = this->getFunction();
mlir::ConversionTarget target(getContext());
@@ -119,10 +119,8 @@ void LowerTensorOpsToLinalgPass::runOnFunction() {
namespace mlir {
namespace zamalang {
namespace HLFHE {
std::unique_ptr<mlir::Pass> createLowerTensorOpsToLinalgPass() {
return std::make_unique<LowerTensorOpsToLinalgPass>();
std::unique_ptr<mlir::Pass> createConvertHLFHETensorOpsToLinalg() {
return std::make_unique<HLFHETensorOpsToLinalg>();
}
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir

View File

@@ -1,2 +1 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(IR)

View File

@@ -6,7 +6,8 @@ target_link_libraries(zamacompiler
MLIRTransforms
MidLFHEDialect
HLFHEDialect
HLFHEDialectTransforms
HLFHETensorOpsToLinalg
HLFHEToMidLFHE
)

View File

@@ -3,6 +3,7 @@
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
@@ -10,14 +11,15 @@
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
namespace cmdline {
llvm::cl::list<std::string> inputs(llvm::cl::Positional,
llvm::cl::desc("<Input files>"),
llvm::cl::OneOrMore);
@@ -27,13 +29,14 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::opt<bool> convertHLFHETensorOpsToLinalg(
"convert-hlfhe-tensor-ops-to-linalg",
llvm::cl::desc("Convert HLFHE tensor operations to linalg operations"));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> convertHLFHEToMidLFHE(
"convert-hlfhe-to-midlfhe",
llvm::cl::desc("Convert HLFHE operations to MidLFHE operations"));
llvm::cl::opt<bool> roundTrip("round-trip",
llvm::cl::desc("Just parse and dump"),
llvm::cl::init(false));
llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
@@ -48,6 +51,24 @@ llvm::cl::opt<bool> splitInputFile(
llvm::cl::init(false));
}; // namespace cmdline
void addPassCmdLineFiltered(mlir::PassManager &pm,
std::unique_ptr<mlir::Pass> pass) {
if (cmdline::roundTrip)
return;
auto passName = pass->getName();
if (cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return pass->getArgument() == p; })) {
if (*pass->getOpName() == "module") {
pm.addPass(std::move(pass));
} else {
pm.nest(*pass->getOpName()).addPass(std::move(pass));
}
}
return;
}
// Process a single source buffer
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
@@ -57,10 +78,10 @@ llvm::cl::opt<bool> splitInputFile(
// If `verifyDiagnostics` is `false`, the procedure checks if the
// parsed module is valid and if all requested transformations
// succeeded.
mlir::LogicalResult processInputBuffer(
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics,
bool convertHLFHETensorOpsToLinalg, bool convertHLFHEToMidLFHE) {
mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
mlir::PassManager pm(&context);
llvm::SourceMgr sourceMgr;
@@ -77,15 +98,9 @@ mlir::LogicalResult processInputBuffer(
if (!module)
return mlir::failure();
if (convertHLFHETensorOpsToLinalg) {
pm.addNestedPass<mlir::FuncOp>(
mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass());
}
if (convertHLFHEToMidLFHE) {
pm.addNestedPass<mlir::FuncOp>(
mlir::zamalang::createConvertHLFHEToMidLFHEPass());
}
addPassCmdLineFiltered(pm,
mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
addPassCmdLineFiltered(pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass());
if (pm.run(*module).failed()) {
llvm::errs() << "Could not run passes!\n";
@@ -112,6 +127,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
if (cmdline::verifyDiagnostics)
context.printOpOnDiagnostic(false);
@@ -141,19 +157,14 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(file),
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics,
cmdline::convertHLFHETensorOpsToLinalg,
cmdline::convertHLFHEToMidLFHE);
return processInputBuffer(context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(context, std::move(file), output->os(),
cmdline::verifyDiagnostics,
cmdline::convertHLFHETensorOpsToLinalg,
cmdline::convertHLFHEToMidLFHE);
cmdline::verifyDiagnostics);
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --convert-hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --convert-hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --convert-hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: memref<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --convert-hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --convert-hlfhe-to-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2>
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
@@ -51,14 +51,14 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: memref<4xi2>) -> !HLFHE.e
return %1: !HLFHE.eint<2>
}
// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi32>, %arg2: memref<!HLFHE.eint<2>>)
// CHECK-LABEL: func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>, %arg1: memref<2xi3>, %arg2: memref<!HLFHE.eint<2>>)
func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<2>>,
%arg1: memref<2xi32>,
%arg1: memref<2xi3>,
%arg2: memref<!HLFHE.eint<2>>)
{
// CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref<!HLFHE.eint<2>>) -> ()
// CHECK-NEXT: "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : (memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
"HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
(memref<2x!HLFHE.eint<2>>, memref<2xi32>, memref<!HLFHE.eint<2>>) -> ()
(memref<2x!HLFHE.eint<2>>, memref<2xi3>, memref<!HLFHE.eint<2>>) -> ()
//CHECK-NEXT: return
return

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> ()>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: memref<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: memref<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s 2>&1| FileCheck %s
// RUN: zamacompiler %s --round-trip 2>&1| FileCheck %s
// CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {