diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index cfeab8508..4d6c02417 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -5,6 +5,7 @@ target_link_libraries(zamacompiler PRIVATE MLIRTransforms MidLFHEDialect - HLFHEDialect) + HLFHEDialect + HLFHEDialectTransforms) mlir_check_all_link_libraries(zamacompiler) diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 3a338aa88..53e8cf21a 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -5,12 +5,14 @@ #include #include #include +#include #include #include #include #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" +#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" @@ -24,6 +26,10 @@ llvm::cl::opt output("o", llvm::cl::value_desc("filename"), llvm::cl::init("-")); +llvm::cl::opt convertHLFHETensorOpsToLinalg( + "convert-hlfhe-tensor-ops-to-linalg", + llvm::cl::desc("Convert HLFHE tensor operations to linalg operations")); + llvm::cl::opt verifyDiagnostics( "verify-diagnostics", llvm::cl::desc("Check that emitted diagnostics match " @@ -44,11 +50,15 @@ llvm::cl::opt splitInputFile( // `expected-error` are produced. // // If `verifyDiagnostics` is `false`, the procedure checks if the -// parsed module is valid. +// parsed module is valid and if all requested transformations +// succeeded. mlir::LogicalResult processInputBuffer(mlir::MLIRContext &context, std::unique_ptr buffer, - llvm::raw_ostream &os, bool verifyDiagnostics) { + llvm::raw_ostream &os, bool verifyDiagnostics, + bool convertHLFHETensorOpsToLinalg) { + mlir::PassManager pm(&context); + llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); @@ -63,6 +73,16 @@ processInputBuffer(mlir::MLIRContext &context, if (!module) return mlir::failure(); + if (convertHLFHETensorOpsToLinalg) { + pm.addNestedPass( + mlir::zamalang::HLFHE::createLowerTensorOpsToLinalgPass()); + } + + if (pm.run(*module).failed()) { + llvm::errs() << "Could not run passes!\n"; + return mlir::failure(); + } + module->print(os); return mlir::success(); @@ -112,14 +132,17 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { std::move(file), [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { - return processInputBuffer(context, std::move(inputBuffer), os, - cmdline::verifyDiagnostics); + return processInputBuffer( + context, std::move(inputBuffer), os, + cmdline::verifyDiagnostics, + cmdline::convertHLFHETensorOpsToLinalg); }, output->os()))) return mlir::failure(); } else { return processInputBuffer(context, std::move(file), output->os(), - cmdline::verifyDiagnostics); + cmdline::verifyDiagnostics, + cmdline::convertHLFHETensorOpsToLinalg); } } diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir new file mode 100644 index 000000000..f117126f2 --- /dev/null +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -0,0 +1,22 @@ +// RUN: zamacompiler %s --convert-hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s + +// CHECK: #map0 = affine_map<(d0) -> (d0)> +// CHECK-NEXT: #map1 = affine_map<(d0) -> ()> +// CHECK-NEXT: module { +// CHECK-NEXT: func @dot_eint_int(%[[A0:.*]]: memref<2x!HLFHE.eint<0>>, %[[A1:.*]]: memref<2xi32>, %[[A2:.*]]: memref>) +// CHECK-NEXT: linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%[[A0]], %[[A1]] : memref<2x!HLFHE.eint<0>>, memref<2xi32>) outs(%arg2 : memref>) { +// CHECK-NEXT: ^bb0(%[[A3:.*]]: !HLFHE.eint<0>, %[[A4:.*]]: i32, %[[A5:.*]]: !HLFHE.eint<0>): // no predecessors +// CHECK-NEXT: %[[T0:.*]] = "HLFHE.mul_eint_int"(%[[A3]], %[[A4]]) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0> +// CHECK-NEXT: %[[T1:.*]] = "HLFHE.add_eint"(%[[T0]], %[[A5]]) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0> +// CHECK-NEXT: linalg.yield %[[T1]] : !HLFHE.eint<0> +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } +func @dot_eint_int(%arg0: memref<2x!HLFHE.eint<0>>, + %arg1: memref<2xi32>, + %arg2: memref>) +{ + "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) : + (memref<2x!HLFHE.eint<0>>, memref<2xi32>, memref>) -> () + return +}