From cdca7ca6f723623138662b3549682d7a13e3fe2a Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Mon, 20 Dec 2021 12:20:33 +0000 Subject: [PATCH] feat(compiler): add Dataflow/RT dialect and code generation for dataflow auto parallelization. --- .../zamalang/Dialect/HLFHE/IR/HLFHEOps.td | 2 +- .../zamalang/Dialect/RT/Analysis/Autopar.h | 28 ++ .../zamalang/Dialect/RT/Analysis/Autopar.td | 87 +++++ .../Dialect/RT/Analysis/CMakeLists.txt | 6 + .../zamalang/Dialect/RT/CMakeLists.txt | 1 + .../include/zamalang/Dialect/RT/IR/RTOps.td | 82 ++++- .../include/zamalang/Dialect/RT/IR/RTTypes.td | 35 +- .../include/zamalang/Support/CompilerEngine.h | 4 +- compiler/include/zamalang/Support/Pipeline.h | 3 + .../HLFHEToMidLFHE/HLFHEToMidLFHE.cpp | 7 + .../LowLFHEUnparametrize.cpp | 7 + .../MLIRLowerableDialectsToLLVM.cpp | 23 +- .../MidLFHEGlobalParametrization.cpp | 7 + .../MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp | 7 + .../RT/Analysis/BufferizeDataflowTaskOps.cpp | 119 +++++++ .../RT/Analysis/BuildDataflowTaskGraph.cpp | 266 ++++++++++++++ .../lib/Dialect/RT/Analysis/CMakeLists.txt | 18 + .../RT/Analysis/LowerDataflowTasksToRT.cpp | 337 ++++++++++++++++++ ...owerRTToLLVMDFRCallsConversionPatterns.cpp | 310 ++++++++++++++++ compiler/lib/Dialect/RT/CMakeLists.txt | 1 + compiler/lib/Dialect/RT/IR/RTDialect.cpp | 17 +- compiler/lib/Dialect/RT/IR/RTOps.cpp | 18 + compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilerEngine.cpp | 9 + compiler/lib/Support/Pipeline.cpp | 22 ++ compiler/src/main.cpp | 11 +- 26 files changed, 1418 insertions(+), 10 deletions(-) create mode 100644 compiler/include/zamalang/Dialect/RT/Analysis/Autopar.h create mode 100644 compiler/include/zamalang/Dialect/RT/Analysis/Autopar.td create mode 100644 compiler/include/zamalang/Dialect/RT/Analysis/CMakeLists.txt create mode 100644 compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp create mode 100644 compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp create mode 100644 
compiler/lib/Dialect/RT/Analysis/CMakeLists.txt create mode 100644 compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp create mode 100644 compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td index c893d7657..0287074a2 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.td @@ -19,7 +19,7 @@ class HLFHE_Op traits = []> : Op; // Generates an encrypted zero constant -def ZeroEintOp : HLFHE_Op<"zero"> { +def ZeroEintOp : HLFHE_Op<"zero", [NoSideEffect]> { let arguments = (ins); let results = (outs EncryptedIntegerType:$out); } diff --git a/compiler/include/zamalang/Dialect/RT/Analysis/Autopar.h b/compiler/include/zamalang/Dialect/RT/Analysis/Autopar.h new file mode 100644 index 000000000..a7609a7c2 --- /dev/null +++ b/compiler/include/zamalang/Dialect/RT/Analysis/Autopar.h @@ -0,0 +1,28 @@ +#ifndef ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR_H +#define ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR_H + +#include +#include +#include + +namespace mlir { + +class LLVMTypeConverter; +class BufferizeTypeConverter; +class RewritePatternSet; + +namespace zamalang { +std::unique_ptr +createBuildDataflowTaskGraphPass(bool debug = false); +std::unique_ptr createLowerDataflowTasksPass(bool debug = false); +std::unique_ptr +createBufferizeDataflowTaskOpsPass(bool debug = false); +std::unique_ptr createFixupDataflowTaskOpsPass(bool debug = false); +void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, + mlir::RewritePatternSet &patterns); +void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/include/zamalang/Dialect/RT/Analysis/Autopar.td b/compiler/include/zamalang/Dialect/RT/Analysis/Autopar.td new file mode 
100644 index 000000000..27ccd327f --- /dev/null +++ b/compiler/include/zamalang/Dialect/RT/Analysis/Autopar.td @@ -0,0 +1,87 @@ +#ifndef ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR +#define ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR + +include "mlir/Pass/PassBase.td" + +def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> { + let summary = + "Identify profitable dataflow tasks and build DataflowTaskGraph."; + + let description = [{ + This pass builds a dataflow graph out of a HLFHE program. + + In its current incarnation, it considers some heavier weight + operations (e.g., HLFHELinalg Dot and Matmul or bootstraps) as + candidates for being executed in a discrete task, and then + sinks within the task the lighter weight operations that do not + increase the graph cut (amount of dependences in or out). + + The output is a program partitioned in RT::DataflowTaskOp that + expose task dependences as arguments and results of the + DataflowTaskOp. + + Example: + +```mlir + func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> { + %0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %0 : tensor<3x2x!HLFHE.eint<2>> + } +``` + + Will result in generating a dataflow task for the Matmul operation: + +```mlir + func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> { + %0 = "RT.dataflow_task"(%arg0, %arg1) ( { + %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + "RT.dataflow_yield"(%1) : (tensor<3x2x!HLFHE.eint<2>>) -> () + }) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %0 : tensor<3x2x!HLFHE.eint<2>> + } +``` + }]; +} + +def BufferizeDataflowTaskOps : Pass<"BufferizeDataflowTaskOps", "mlir::ModuleOp"> { + let summary = + "Bufferize DataflowTaskOp(s)."; + + let description = [{ + 
This pass lowers DataflowTaskOp arguments and results from tensors + to mlir::memref. It also lowers the arguments of DataflowYieldOp. + }]; +} + +def FixupDataflowTaskOps : Pass<"FixupDataflowTaskOps", "mlir::ModuleOp"> { + let summary = + "Fix DataflowTaskOp(s) before lowering."; + + let description = [{ + This pass fixes up code changes that intervene between the + BuildDataflowTaskGraph pass and the lowering of the task graph to + LLVMIR and calls to the DFR runtime system. + + In particular, some operations (e.g., constants, dimension + operations, etc.) can be used within the task while only defined + outside. In most cases cloning and sinking these operations in the + task is the simplest way to avoid adding dependences. + + }]; +} + +def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> { + let summary = + "Outline the body of a DataflowTaskOp into a separate function which will serve as a task work function and lower the task graph to RT."; + + let description = [{ + This pass lowers a DataflowTaskGraph to the RT dialect, outlining + DataflowTaskOp into separate work functions and introducing the + necessary operations to communicate and synchronize execution via + futures. 
+ }]; +} + + + +#endif diff --git a/compiler/include/zamalang/Dialect/RT/Analysis/CMakeLists.txt b/compiler/include/zamalang/Dialect/RT/Analysis/CMakeLists.txt new file mode 100644 index 000000000..8b3423d9e --- /dev/null +++ b/compiler/include/zamalang/Dialect/RT/Analysis/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Autopar.td) +mlir_tablegen(Autopar.h.inc -gen-pass-decls -name Analysis) +mlir_tablegen(Autopar.capi.h.inc -gen-pass-capi-header --prefix Analysis) +mlir_tablegen(Autopar.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis) +add_public_tablegen_target(AutoparPassIncGen) + diff --git a/compiler/include/zamalang/Dialect/RT/CMakeLists.txt b/compiler/include/zamalang/Dialect/RT/CMakeLists.txt index f33061b2d..4f7494893 100644 --- a/compiler/include/zamalang/Dialect/RT/CMakeLists.txt +++ b/compiler/include/zamalang/Dialect/RT/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(Analysis) add_subdirectory(IR) diff --git a/compiler/include/zamalang/Dialect/RT/IR/RTOps.td b/compiler/include/zamalang/Dialect/RT/IR/RTOps.td index 3a3f2066e..bca9edef9 100644 --- a/compiler/include/zamalang/Dialect/RT/IR/RTOps.td +++ b/compiler/include/zamalang/Dialect/RT/IR/RTOps.td @@ -3,6 +3,8 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/DataLayoutInterfaces.td" include "zamalang/Dialect/RT/IR/RTDialect.td" include "zamalang/Dialect/RT/IR/RTTypes.td" @@ -16,9 +18,24 @@ def DataflowTaskOp : RT_Op<"dataflow_task", [SingleBlockImplicitTerminator<"Data let regions = (region AnyRegion:$body); + let builders = [ + OpBuilder<(ins + CArg<"TypeRange", "{}">: $resultTypes, + CArg<"ValueRange", "{}">: $operands, + CArg<"ArrayRef", "{}">: $attrs)> + ]; + let skipDefaultBuilders = 1; + let summary = "Dataflow task operation"; let description = [{ -`RT.dataflow_task` allows to specify a task that will be concurrently executed when their operands are ready. 
+ +`RT.dataflow_task` allows to specify a task that will be concurrently +executed when its operands are ready. Operands are either the +results of computation in other `RT.dataflow_task` (dataflow +dependences) or obtained from the execution context (immediate +operands). Operands are synchronized using futures and, in the case +of immediate operands, copied when the task is created. Caution is +required when the operand is a pointer as no deep copy will occur. Example: @@ -43,6 +60,7 @@ func @test(%0 : i64): (i64, i64) { }) : (i64, i64) -> (i64, i64) return %3, %4 : (i64, i64) } +``` }]; } @@ -52,7 +70,7 @@ def DataflowYieldOp : RT_Op<"dataflow_yield", [ReturnLike, Terminator]> { let summary = "Dataflow yield operation"; let description = [{ `RT.dataflow_yield` is a special terminator operation for blocks inside the region -in `RT.dataflow_task`. It allows to specify the returns values of a `RT.dataflow_task`. +in `RT.dataflow_task`. It allows to specify the return values of a `RT.dataflow_task`. Example: @@ -64,4 +82,64 @@ Example: }]; } +def MakeReadyFutureOp : RT_Op<"make_ready_future"> { + let arguments = (ins AnyType: $input); + let results = (outs RT_Future: $output); + let summary = "Build a ready future."; + let description = [{ +Data passed to dataflow tasks must be encapsulated in futures, +including immediate operands. These must be converted into futures +using `RT.make_ready_future`. +}]; +} + +def AwaitFutureOp : RT_Op<"await_future"> { + let arguments = (ins RT_Future: $input); + let results = (outs AnyType: $output); + let summary = "Wait for a future and access its data."; + let description = [{ +The results of a dataflow task are always futures which could be +further used as inputs to subsequent tasks. When the result of a task +is needed in the outer execution context, the result future needs to +be synchronized and its data accessed using `RT.await_future`. 
+}]; +} + +def CreateAsyncTaskOp : RT_Op<"create_async_task"> { + let arguments = (ins SymbolRefAttr:$workfn, + Variadic:$list); + let results = (outs ); + let summary = "Create a dataflow task."; +} + +def DeallocateFutureOp : RT_Op<"deallocate_future"> { + let arguments = (ins RT_Future: $input); + let results = (outs ); +} + +def DeallocateFutureDataOp : RT_Op<"deallocate_future_data"> { + let arguments = (ins RT_Future: $input); + let results = (outs ); +} + +def BuildReturnPtrPlaceholderOp : RT_Op<"build_return_ptr_placeholder"> { + let arguments = (ins ); + let results = (outs RT_Pointer: $output); +} + +def DerefReturnPtrPlaceholderOp : RT_Op<"deref_return_ptr_placeholder"> { + let arguments = (ins RT_Pointer: $input); + let results = (outs RT_Future: $output); +} + +def DerefWorkFunctionArgumentPtrPlaceholderOp : RT_Op<"deref_work_function_argument_ptr_placeholder"> { + let arguments = (ins RT_Pointer: $input); + let results = (outs AnyType: $output); +} + +def WorkFunctionReturnOp : RT_Op<"work_function_return"> { + let arguments = (ins AnyType:$in, AnyType:$out); + let results = (outs ); +} + #endif diff --git a/compiler/include/zamalang/Dialect/RT/IR/RTTypes.td b/compiler/include/zamalang/Dialect/RT/IR/RTTypes.td index 13cb887e4..7d7d577a5 100644 --- a/compiler/include/zamalang/Dialect/RT/IR/RTTypes.td +++ b/compiler/include/zamalang/Dialect/RT/IR/RTTypes.td @@ -46,7 +46,40 @@ def RT_Future : RT_Type<"Future"> { return Type(); return get($_ctxt, elementType); }]; - //let genVerifyDecl = 1; +} + +def RT_Pointer : RT_Type<"Pointer"> { + let mnemonic = "rtptr"; + + let summary = "Pointer to a parameterized element type"; + + let description = [{ + }]; + + let parameters = (ins "Type":$elementType); + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ + return $_get(elementType.getContext(), elementType); + }]> + ]; + + let printer = [{ + $_printer << "rtptr<"; + $_printer.printType(getElementType()); + $_printer << ">"; + }]; 
+ + let parser = [{ + if ($_parser.parseLess()) + return Type(); + Type elementType; + if ($_parser.parseType(elementType)) + return Type(); + if ($_parser.parseGreater()) + return Type(); + return get($_ctxt, elementType); + }]; } #endif diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index 354107b91..a4a566e5f 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -122,7 +122,7 @@ public: CompilerEngine(std::shared_ptr compilationContext) : overrideMaxEintPrecision(), overrideMaxMANP(), clientParametersFuncName(), verifyDiagnostics(false), - generateClientParameters(false), + autoParallelize(false), generateClientParameters(false), enablePass([](mlir::Pass *pass) { return true; }), compilationContext(compilationContext) {} @@ -146,6 +146,7 @@ public: void setMaxEintPrecision(size_t v); void setMaxMANP(size_t v); void setVerifyDiagnostics(bool v); + void setAutoParallelize(bool v); void setGenerateClientParameters(bool v); void setClientParametersFuncName(const llvm::StringRef &name); void setHLFHELinalgTileSizes(llvm::ArrayRef sizes); @@ -158,6 +159,7 @@ protected: llvm::Optional> hlfhelinalgTileSizes; bool verifyDiagnostics; + bool autoParallelize; bool generateClientParameters; std::function enablePass; diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h index 3bab5ee85..e99c2222d 100644 --- a/compiler/include/zamalang/Support/Pipeline.h +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -12,6 +12,9 @@ namespace mlir { namespace zamalang { namespace pipeline { +mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + llvm::Expected> getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); diff --git a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp 
b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp index d75d1067a..c217c0c63 100644 --- a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp +++ b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp @@ -11,6 +11,7 @@ #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" +#include "zamalang/Dialect/RT/IR/RTOps.h" namespace { struct HLFHEToMidLFHEPass : public HLFHEToMidLFHEBase { @@ -92,6 +93,12 @@ void HLFHEToMidLFHEPass::runOnOperation() { converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); + // Conversion of RT Dialect Ops + patterns.add>(patterns.getContext(), converter); + mlir::zamalang::addDynamicallyLegalTypeOp( + target, converter); + // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); diff --git a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp index 32478330c..219a7b415 100644 --- a/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp +++ b/compiler/lib/Conversion/LowLFHEUnparametrize/LowLFHEUnparametrize.cpp @@ -7,6 +7,7 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Dialect/RT/IR/RTOps.h" #include "zamalang/Support/Constants.h" /// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize @@ -123,6 +124,12 @@ void LowLFHEUnparametrizePass::runOnOperation() { patterns.getContext(), converter); mlir::zamalang::addDynamicallyLegalTypeOp(target, converter); + // Conversion of RT Dialect Ops + patterns.add>(patterns.getContext(), converter); + mlir::zamalang::addDynamicallyLegalTypeOp( + target, converter); + // Apply conversion if (mlir::applyPartialConversion(op, target, 
std::move(patterns)).failed()) { this->signalPassFailure(); diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index 0fba4a7bb..80097644d 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -19,6 +19,8 @@ #include "zamalang/Conversion/Passes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Dialect/RT/Analysis/Autopar.h" +#include "zamalang/Dialect/RT/IR/RTTypes.h" namespace { struct MLIRLowerableDialectsToLLVMPass @@ -52,6 +54,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() { // Setup the set of the patterns rewriter. At this point we want to // convert the `scf` operations to `std` and `std` operations to `llvm`. mlir::RewritePatternSet patterns(&getContext()); + mlir::zamalang::populateRTToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); @@ -72,10 +75,28 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { type.isa() || type.isa() || type.isa() || - type.isa()) { + type.isa() || + type.isa()) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(type.getContext(), 64)); } + if (type.isa()) { + mlir::LowerToLLVMOptions options(type.getContext()); + mlir::LLVMTypeConverter typeConverter(type.getContext(), options); + typeConverter.addConversion(convertTypes); + typeConverter.addConversion( + [&](mlir::zamalang::LowLFHE::PlaintextType type) { + return mlir::IntegerType::get(type.getContext(), 64); + }); + typeConverter.addConversion( + [&](mlir::zamalang::LowLFHE::CleartextType type) { + return mlir::IntegerType::get(type.getContext(), 64); + }); + mlir::Type subtype = + 
type.dyn_cast().getElementType(); + mlir::Type convertedSubtype = typeConverter.convertType(subtype); + return mlir::LLVM::LLVMPointerType::get(convertedSubtype); + } return llvm::None; } diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index ec6aabbcc..bbe6e0439 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -7,6 +7,7 @@ #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" +#include "zamalang/Dialect/RT/IR/RTOps.h" #include "zamalang/Support/Constants.h" namespace { @@ -300,6 +301,12 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() { mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, converter); + // Conversion of RT Dialect Ops + patterns.add>(patterns.getContext(), converter); + mlir::zamalang::addDynamicallyLegalTypeOp< + mlir::zamalang::RT::DataflowTaskOp>(target, converter); + // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index 934bf0cf5..12a56f090 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -11,6 +11,7 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" +#include "zamalang/Dialect/RT/IR/RTOps.h" namespace { struct MidLFHEToLowLFHEPass @@ -89,6 +90,12 @@ void MidLFHEToLowLFHEPass::runOnOperation() { converter); 
mlir::populateFuncOpTypeConversionPattern(patterns, converter); + // Conversion of RT Dialect Ops + patterns.add>(patterns.getContext(), converter); + mlir::zamalang::addDynamicallyLegalTypeOp( + target, converter); + // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { this->signalPassFailure(); diff --git a/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp b/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp new file mode 100644 index 000000000..24ec02049 --- /dev/null +++ b/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp @@ -0,0 +1,119 @@ +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace zamalang { + +namespace { +class BufferizeDataflowYieldOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(RT::DataflowYieldOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::DataflowYieldOp::Adaptor transformed(operands); + rewriter.replaceOpWithNewOp(op, mlir::TypeRange(), + transformed.getOperands()); + return success(); + } +}; +} // namespace + +namespace { +class BufferizeDataflowTaskOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(RT::DataflowTaskOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::DataflowTaskOp::Adaptor transformed(operands); + mlir::OpBuilder::InsertionGuard guard(rewriter); + + SmallVector newResults; + (void)getTypeConverter()->convertTypes(op.getResultTypes(), newResults); + auto newop = rewriter.create(op.getLoc(), newResults, + transformed.getOperands()); + // We cannot clone here as cloned ops must be legalized (so this + // would break on the YieldOp). 
Instead use mergeBlocks which + // moves the ops instead of cloning. + rewriter.mergeBlocks(op.getBody(), newop.getBody(), + newop.getBody()->getArguments()); + // Because of previous bufferization there are buffer cast ops + // that have been generated for the previously tensor results of + // some tasks. These cannot just be replaced directly as the + // task's results would still be live. + for (auto res : llvm::enumerate(op.getResults())) { + // If this result is getting bufferized ... + if (res.value().getType() != + getTypeConverter()->convertType(res.value().getType())) { + for (auto &use : llvm::make_early_inc_range(res.value().getUses())) { + // ... and its uses are in `BufferCastOp`s, then we + // replace further uses of the buffer cast. + if (isa(use.getOwner())) { + rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())}); + } + } + } + } + rewriter.replaceOp(op, {newop.getResults()}); + return success(); + } +}; +} // namespace + +void populateRTBufferizePatterns(BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add( + typeConverter, patterns.getContext()); +} + +namespace { +// For documentation see Autopar.td +struct BufferizeDataflowTaskOpsPass + : public BufferizeDataflowTaskOpsBase { + + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + BufferizeTypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + populateRTBufferizePatterns(typeConverter, patterns); + + // Forbid all RT ops that still use/return tensors + target.addDynamicallyLegalDialect( + [&](Operation *op) { return typeConverter.isLegal(op); }); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } + + BufferizeDataflowTaskOpsPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // end anonymous namespace + +std::unique_ptr createBufferizeDataflowTaskOpsPass(bool debug) { + return 
std::make_unique(debug); +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp new file mode 100644 index 000000000..bfb23e6b5 --- /dev/null +++ b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp @@ -0,0 +1,266 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace zamalang { + +namespace { + +// TODO: adjust these two functions based on cost model +static bool isCandidateForTask(Operation *op) { + return isa(op); +} + +// Identify operations that are beneficial to sink into tasks. These +// operations must not have side-effects and not be `isCandidateForTask` +static bool isSinkingBeneficiary(Operation *op) { + return isa(op); +} + +static bool +extractBeneficiaryOps(Operation *op, SetVector existingDependencies, + SetVector &beneficiaryOps, + llvm::SmallPtrSetImpl &availableValues) { + if (beneficiaryOps.count(op)) + return true; + + if (!isSinkingBeneficiary(op)) + return false; + + for (Value operand : op->getOperands()) { + // It is already visible in the kernel, keep going. + if (availableValues.count(operand)) + continue; + // Else check whether it can be made available via sinking or already is a + // dependency. + Operation *definingOp = operand.getDefiningOp(); + if ((!definingOp || + !extractBeneficiaryOps(definingOp, existingDependencies, + beneficiaryOps, availableValues)) && + !existingDependencies.count(operand)) + return false; + } + // We will sink the operation, mark its results as now available. 
+ beneficiaryOps.insert(op); + for (Value result : op->getResults()) + availableValues.insert(result); + return true; +} + +LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { + Region &taskOpBody = taskOp.body(); + + // Identify uses from values defined outside of the scope. + SetVector sinkCandidates; + getUsedValuesDefinedAbove(taskOpBody, sinkCandidates); + + SetVector toBeSunk; + llvm::SmallPtrSet availableValues; + for (Value operand : sinkCandidates) { + Operation *operandOp = operand.getDefiningOp(); + if (!operandOp) + continue; + extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues); + } + + // Insert operations so that the defs get cloned before uses. + BlockAndValueMapping map; + OpBuilder builder(taskOpBody); + for (Operation *op : toBeSunk) { + OpBuilder::InsertionGuard guard(builder); + Operation *clonedOp = builder.clone(*op, map); + for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) + replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), + taskOp.body()); + // Once this is sunk, remove all operands of the DFT covered by this + for (auto result : op->getResults()) + for (auto operand : llvm::enumerate(taskOp.getOperands())) + if (operand.value() == result) { + taskOp->eraseOperand(operand.index()); + // Once removed, we assume there are no duplicates + break; + } + } + return success(); +} + +// For documentation see Autopar.td +struct BuildDataflowTaskGraphPass + : public BuildDataflowTaskGraphBase { + + void runOnOperation() override { + auto module = getOperation(); + + module.walk([&](mlir::FuncOp func) { + if (!func->getAttr("_dfr_work_function_attribute")) + func.walk( + [&](mlir::Operation *childOp) { this->processOperation(childOp); }); + + // Perform simplifications, in particular DCE here in case some + // of the operations sunk in tasks are no longer needed in the + // main function. If the function fails it only means that + // nothing was simplified. 
Doing this here - rather than later + // in the compilation pipeline - allows to take advantage of + // higher level semantics which we can attach to operations + // (e.g., NoSideEffect on HLFHE::ZeroEintOp). + IRRewriter rewriter(func->getContext()); + (void)mlir::simplifyRegions(rewriter, func->getRegions()); + }); + } + BuildDataflowTaskGraphPass(bool debug) : debug(debug){}; + +protected: + void processOperation(mlir::Operation *op) { + if (isCandidateForTask(op)) { + BlockAndValueMapping map; + Region &opBody = getOperation().body(); + OpBuilder builder(opBody); + + // Create a DFTask for this operation + builder.setInsertionPointAfter(op); + auto dftop = builder.create( + op->getLoc(), op->getResultTypes(), op->getOperands()); + // Add the operation to the task + OpBuilder tbbuilder(dftop.body()); + Operation *clonedOp = tbbuilder.clone(*op, map); + // Add sinkable operations to the task + assert(!failed(sinkOperationsIntoDFTask(dftop)) && + "Failing to sink operations into DFT"); + + // Add terminator + tbbuilder.create(dftop.getLoc(), mlir::TypeRange(), + op->getResults()); + // Replace the uses of defined values + for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) + replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), + dftop.body()); + // Replace uses of the values defined by the task + for (auto pair : llvm::zip(op->getResults(), dftop->getResults())) + replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), + opBody); + // Once uses are re-targeted to the task, delete the operation + op->erase(); + } + } + + bool debug; +}; +} // end anonymous namespace + +std::unique_ptr createBuildDataflowTaskGraphPass(bool debug) { + return std::make_unique(debug); +} + +namespace { +// Marker to avoid infinite recursion of the rewriting pattern +static const mlir::StringLiteral kTransformMarker = + "_internal_RT_FixDataflowTaskOpInputsPattern_marker__"; + +class FixDataflowTaskOpInputsPattern + : public mlir::OpRewritePattern 
{ +public: + FixDataflowTaskOpInputsPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern( + context, ::mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {} + + LogicalResult + matchAndRewrite(RT::DataflowTaskOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::OpBuilder::InsertionGuard guard(rewriter); + + if (op->hasAttr(kTransformMarker)) + return failure(); + + // Identify which values need to be passed as dependences to the + // task - this is very conservative and will add constants, index + // operations, etc. A simplification will occur later. + SetVector deps; + getUsedValuesDefinedAbove(op.body(), deps); + auto newop = rewriter.create( + op.getLoc(), op.getResultTypes(), deps.getArrayRef()); + rewriter.mergeBlocks(op.getBody(), newop.getBody(), + newop.getBody()->getArguments()); + rewriter.replaceOp(op, {newop.getResults()}); + + // Mark this as processed to prevent infinite loop + newop.getOperation()->setAttr(kTransformMarker, rewriter.getUnitAttr()); + return success(); + } +}; +} // namespace + +namespace { +// For documentation see Autopar.td +struct FixupDataflowTaskOpsPass + : public FixupDataflowTaskOpsBase { + + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + patterns.add(context); + + if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)) + .failed()) + signalPassFailure(); + + // Clear mark and sink any newly created constants or indexing + // operations, etc. 
to reduce the number of input dependences to + // the task + module->walk([](RT::DataflowTaskOp op) { + op.getOperation()->removeAttr(kTransformMarker); + if (failed(sinkOperationsIntoDFTask(op))) // kept out of assert() so the sinking still runs when NDEBUG is defined + llvm::report_fatal_error("Failing to sink operations into DFT"); + }); + } + + FixupDataflowTaskOpsPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // end anonymous namespace + +std::unique_ptr createFixupDataflowTaskOpsPass(bool debug) { + return std::make_unique(debug); +} + +} // end namespace zamalang +} // end namespace mlir diff --git a/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt b/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt new file mode 100644 index 000000000..4ee50e1ad --- /dev/null +++ b/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_library(RTDialectAnalysis + BufferizeDataflowTaskOps.cpp + BuildDataflowTaskGraph.cpp + LowerDataflowTasksToRT.cpp + LowerRTToLLVMDFRCallsConversionPatterns.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/RT + + DEPENDS + RTDialect + AutoparPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + RTDialect) + +target_link_libraries(RTDialectAnalysis PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp new file mode 100644 index 000000000..5d81e0dad --- /dev/null +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -0,0 +1,337 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace zamalang { + +namespace { + +static FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp, + 
StringRef workFunctionName) { + Location loc = DFTOp.getLoc(); + OpBuilder builder(DFTOp.getContext()); + Region &DFTOpBody = DFTOp.body(); + OpBuilder::InsertionGuard guard(builder); + + // Instead of outlining with the same operands/results, we pass all + // results as operands as well. For now we preserve the results' + // types, which will be changed to use an indirection when lowering. + SmallVector operandTypes; + operandTypes.reserve(DFTOp.getNumOperands() + DFTOp.getNumResults()); + for (Value operand : DFTOp.getOperands()) + operandTypes.push_back(RT::PointerType::get(operand.getType())); + for (Value res : DFTOp.getResults()) + operandTypes.push_back(RT::PointerType::get(res.getType())); + + FunctionType type = FunctionType::get(DFTOp.getContext(), operandTypes, {}); + auto outlinedFunc = builder.create(loc, workFunctionName, type); + outlinedFunc->setAttr("_dfr_work_function_attribute", builder.getUnitAttr()); + Region &outlinedFuncBody = outlinedFunc.body(); + Block *outlinedEntryBlock = new Block; + outlinedEntryBlock->addArguments(type.getInputs()); + outlinedFuncBody.push_back(outlinedEntryBlock); + + BlockAndValueMapping map; + Block &entryBlock = outlinedFuncBody.front(); + builder.setInsertionPointToStart(&entryBlock); + for (auto operand : llvm::enumerate(DFTOp.getOperands())) { + // Add deref of arguments and remap to operands in the body + auto derefdop = + builder.create( + DFTOp.getLoc(), operand.value().getType(), + entryBlock.getArgument(operand.index())); + map.map(operand.value(), derefdop->getResult(0)); + } + DFTOpBody.cloneInto(&outlinedFuncBody, map); + + Block &DFTOpEntry = DFTOpBody.front(); + Block *clonedDFTOpEntry = map.lookup(&DFTOpEntry); + builder.setInsertionPointToEnd(&entryBlock); + builder.create(loc, clonedDFTOpEntry); + + // TODO: we use a WorkFunctionReturnOp to tie return to the + // corresponding argument. 
This can be lowered to a copy/deref for + // shared memory and pointers, but needs to be handled for + // distributed memory. + outlinedFunc.walk([&](RT::DataflowYieldOp op) { + OpBuilder replacer(op); + int output_offset = DFTOp.getNumOperands(); + for (auto ret : llvm::enumerate(op.getOperands())) + replacer.create( + op.getLoc(), ret.value(), + outlinedFunc.getArgument(ret.index() + output_offset)); + replacer.create(op.getLoc()); + op.erase(); + }); + return outlinedFunc; +} + +static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement, + Region &region) { // Re-targets only the uses of orig whose owner passes the isa<> check and lies within region + for (auto &use : llvm::make_early_inc_range(orig.getUses())) { + if (isa(use.getOwner()) && + region.isAncestor(use.getOwner()->getParentRegion())) + use.set(replacement); + } +} +static void replaceAllUsesNotInDFTsInRegionWith(Value orig, Value replacement, + Region &region) { // Re-targets only the uses within region whose owner fails the isa<> check and has no enclosing parent of the queried op type + for (auto &use : llvm::make_early_inc_range(orig.getUses())) { + if (!isa(use.getOwner()) && + use.getOwner()->getParentOfType() == nullptr && + region.isAncestor(use.getOwner()->getParentRegion())) + use.set(replacement); + } +} + +// TODO: Fix type sizes. For now we're using some default values. +static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) { + DataLayout dataLayout = DataLayout::closest(val.getDefiningOp()); + Type type = (val.getType().isa()) + ? val.getType().dyn_cast().getElementType() + : val.getType(); + + // In the case of memref, we need to determine how much space + // (conservatively) we need to store the memref itself. Overshooting + // by a few bytes should not be an issue, so the main thing is to + // properly account for the rank. 
+ if (type.isa()) { + // Space for the allocated and aligned pointers, and offset + Value ptrs_offset = + builder.create(loc, builder.getI64IntegerAttr(24)); + // For the sizes and shapes arrays, we need 2*8 = 16 times the rank in bytes + Value multiplier = + builder.create(loc, builder.getI64IntegerAttr(16)); + unsigned _rank = type.dyn_cast().getRank(); + Value rank = builder.create( + loc, builder.getI64IntegerAttr(_rank)); + Value sizes_shapes = builder.create(loc, rank, multiplier); + Value result = builder.create(loc, ptrs_offset, sizes_shapes); + return result; + } + + // Unranked memrefs should be lowered to just pointer + size, so we need 16 + // bytes. + if (type.isa()) + return builder.create(loc, + builder.getI64IntegerAttr(16)); + + // FHE types are converted to pointers, so we take their size as 8 + // bytes until we can get the actual size of the actual types. + if (type.isa() || + type.isa() || + type.isa() || + type.isa() || + type.isa() || + type.isa() || + type.isa()) + return builder.create(loc, builder.getI64IntegerAttr(8)); + + // For all other types, get type size. + return builder.create( + loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type))); +} + +static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, FuncOp workFunction) { + DataLayout dataLayout = DataLayout::closest(DFTOp); + Region &opBody = DFTOp->getParentOfType().body(); + BlockAndValueMapping map; + OpBuilder builder(DFTOp); + + // First identify DFT operands that are not futures and are not + // defined by another DFT. These need to be made into futures and + // propagated to all other DFTs. We can allow PRE to eliminate the + // previous definitions if there are no non-future type uses. 
+ builder.setInsertionPoint(DFTOp); + for (Value val : DFTOp.getOperands()) { + if (!val.getType().isa()) { + Type futType = RT::FutureType::get(val.getType()); + auto mrf = + builder.create(DFTOp.getLoc(), futType, val); + map.map(mrf->getResult(0), val); + replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody); + } + } + + // Second generate a CreateAsyncTaskOp that will replace the + // DataflowTaskOp. This also includes the necessary handling of + // operands and results (conversion to/from futures and propagation). + SmallVector catOperands; + int size = 3 + DFTOp.getNumResults() * 2 + DFTOp.getNumOperands() * 2; + catOperands.reserve(size); + auto fnptr = builder.create( + DFTOp.getLoc(), workFunction.getType(), + SymbolRefAttr::get(builder.getContext(), workFunction.getName())); + auto numIns = builder.create( + DFTOp.getLoc(), builder.getI64IntegerAttr(DFTOp.getNumOperands())); + auto numOuts = builder.create( + DFTOp.getLoc(), builder.getI64IntegerAttr(DFTOp.getNumResults())); + catOperands.push_back(fnptr.getResult()); + catOperands.push_back(numIns.getResult()); + catOperands.push_back(numOuts.getResult()); + for (auto operand : DFTOp.getOperands()) { + catOperands.push_back(operand); + catOperands.push_back(getSizeInBytes(operand, DFTOp.getLoc(), builder)); + } + + // We need to adjust the results for the CreateAsyncTaskOp which + // are the work function's returns through pointers passed as + // parameters. As this is not supported within MLIR - and mostly + // unsupported even in the LLVMIR Dialect - this needs to use two + // placeholders for each output, before and after the + // CreateAsyncTaskOp. 
+ for (auto result : DFTOp.getResults()) { + Type futType = RT::PointerType::get(RT::FutureType::get(result.getType())); + auto brpp = builder.create(DFTOp.getLoc(), + futType); + map.map(result, brpp->getResult(0)); + catOperands.push_back(brpp->getResult(0)); + catOperands.push_back(getSizeInBytes(result, DFTOp.getLoc(), builder)); + } + builder.create( + DFTOp.getLoc(), + SymbolRefAttr::get(builder.getContext(), workFunction.getName()), + catOperands); + + // Third identify results of this DFT that are not used *only* in + // other DFTs as those will need to be waited on explicitly. + // We also create the DerefReturnPtrPlaceholderOp after the + // CreateAsyncTaskOp. These also need propagating. + for (auto result : DFTOp.getResults()) { + Type futType = RT::FutureType::get(result.getType()); + Value futptr = map.lookupOrNull(result); + assert(futptr); + auto drpp = builder.create( + DFTOp.getLoc(), futType, futptr); + replaceAllUsesInDFTsInRegionWith(result, drpp->getResult(0), opBody); + + for (auto &use : llvm::make_early_inc_range(result.getUses())) { + if (!isa(use.getOwner()) && + use.getOwner()->getParentOfType() == nullptr) { + // Wait for this future + // TODO: the wait function should ideally + // be issued as late as possible, but need to identify which + // use comes first. + auto af = builder.create( + DFTOp.getLoc(), result.getType(), drpp.getResult()); + replaceAllUsesNotInDFTsInRegionWith(result, af->getResult(0), opBody); + // We only need to do this once, propagation will hit all + // other uses + break; + } + } + // All leftover uses (i.e. those within DFTs) should use the future + replaceAllUsesInRegionWith(result, futptr, opBody); + } + + // Finally erase the DFT. 
+ DFTOp.erase(); +} + +// For documentation see Autopar.td +struct LowerDataflowTasksPass + : public LowerDataflowTasksBase { + + void runOnOperation() override { + auto module = getOperation(); + + module.walk([&](mlir::FuncOp func) { + int wfn_id = 0; + + // TODO: For now do not attempt to use nested parallelism. + if (func->getAttr("_dfr_work_function_attribute")) + return; + + SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func); + std::vector> outliningMap; + + func.walk([&](RT::DataflowTaskOp op) { + std::string workFunctionName = (Twine("_dfr_DFT_work_function__") + + Twine(op->getParentOfType().getName()) + + Twine(wfn_id++)).str(); // materialized at once: a stored Twine would refer to destroyed temporaries + FuncOp outlinedFunc = outlineWorkFunction(op, workFunctionName); + outliningMap.push_back( + std::pair(op, outlinedFunc)); + symbolTable.insert(outlinedFunc); + return WalkResult::advance(); + }); + + // Lower the DF task ops to RT dialect ops. + for (auto mapping : outliningMap) + lowerDataflowTaskOp(mapping.first, mapping.second); + + // Issue _dfr_start/stop calls for this function + if (!outliningMap.empty()) { + OpBuilder builder(func.body()); + builder.setInsertionPointToStart(&func.body().front()); + auto dfrStartFunOp = mlir::LLVM::lookupOrCreateFn( + func->getParentOfType(), "_dfr_start", {}, + LLVM::LLVMVoidType::get(func->getContext())); + builder.create(func.getLoc(), dfrStartFunOp, + mlir::ValueRange(), + ArrayRef()); + + builder.setInsertionPoint(func.body().back().getTerminator()); + auto dfrStopFunOp = mlir::LLVM::lookupOrCreateFn( + func->getParentOfType(), "_dfr_stop", {}, + LLVM::LLVMVoidType::get(func->getContext())); + builder.create(func.getLoc(), dfrStopFunOp, + mlir::ValueRange(), + ArrayRef()); + } + }); + } + LowerDataflowTasksPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // end anonymous namespace + +std::unique_ptr createLowerDataflowTasksPass(bool debug) { + return std::make_unique(debug); +} + +} // end namespace zamalang +} // end namespace mlir diff --git 
a/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp b/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp new file mode 100644 index 000000000..559f34565 --- /dev/null +++ b/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp @@ -0,0 +1,310 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace zamalang { + +namespace { + +mlir::Type getVoidPtrI64Type(ConversionPatternRewriter &rewriter) { + return mlir::LLVM::LLVMPointerType::get( + mlir::IntegerType::get(rewriter.getContext(), 64)); +} + +LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op, + llvm::StringRef funcName, + LLVM::LLVMFunctionType funcType, + ConversionPatternRewriter &rewriter) { + // Check if the function is already in the symbol table + auto module = op->getParentOfType(); + auto funcOp = module.lookupSymbol(funcName); + if (!funcOp) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + funcOp = + rewriter.create(op->getLoc(), funcName, funcType); + funcOp.setPrivate(); + } else { + if (!funcOp.isPrivate()) { + op->emitError() + << "the function \"" << funcName + << "\" conflicts with the Dataflow Runtime API, please rename."; + return nullptr; + } + } + return funcOp; +} + +// This function is only needed for debug purposes to inspect values +// in the generated code - it is therefore not generally in use. 
+LLVM_ATTRIBUTE_UNUSED void +insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op, + Value val) { + OpBuilder::InsertionGuard guard(rewriter); + auto printFnType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(rewriter.getContext()), {}, /*isVariadic=*/true); + auto printFnOp = + getOrInsertFuncOpDecl(op, "_dfr_print_debug", printFnType, rewriter); + rewriter.create(op->getLoc(), printFnOp, val); +} + +struct MakeReadyFutureOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::MakeReadyFutureOp mrfOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + + RT::MakeReadyFutureOp::Adaptor transformed(operands); + OpBuilder::InsertionGuard guard(rewriter); + + // Normally this function takes a pointer as parameter + auto mrfFuncType = LLVM::LLVMFunctionType::get(getVoidPtrI64Type(rewriter), + {}, /*isVariadic=*/true); + auto mrfFuncOp = getOrInsertFuncOpDecl(mrfOp, "_dfr_make_ready_future", + mrfFuncType, rewriter); + + // In order to support non pointer types, we need to allocate + // explicitly space that we can reference as a base for the + // future. 
+ auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn( + mrfOp->getParentOfType(), getIndexType()); + auto sizeBytes = getSizeInBytes( + mrfOp.getLoc(), transformed.getOperands().getTypes().front(), rewriter); + auto results = mlir::LLVM::createLLVMCall( + rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType()); + Value allocatedPtr = rewriter.create( + mrfOp.getLoc(), + mlir::LLVM::LLVMPointerType::get( + transformed.getOperands().getTypes().front()), + results[0]); + rewriter.create( + mrfOp.getLoc(), transformed.getOperands().front(), allocatedPtr); + rewriter.replaceOpWithNewOp(mrfOp, mrfFuncOp, allocatedPtr); + + return mlir::success(); + } +}; +struct AwaitFutureOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::AwaitFutureOp afOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::AwaitFutureOp::Adaptor transformed(operands); + OpBuilder::InsertionGuard guard(rewriter); + auto afFuncType = LLVM::LLVMFunctionType::get( + mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)), + {getVoidPtrI64Type(rewriter)}); + auto afFuncOp = + getOrInsertFuncOpDecl(afOp, "_dfr_await_future", afFuncType, rewriter); + auto afCallOp = rewriter.create(afOp.getLoc(), afFuncOp, + transformed.getOperands()); + Value futVal = rewriter.create( + afOp.getLoc(), + mlir::LLVM::LLVMPointerType::get( + (*getTypeConverter()).convertType(afOp.getResult().getType())), + afCallOp.getResult(0)); + rewriter.replaceOpWithNewOp(afOp, futVal); + return success(); + } +}; +struct CreateAsyncTaskOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::CreateAsyncTaskOp catOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::CreateAsyncTaskOp::Adaptor transformed(operands); + auto catFuncType = + 
LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true); + auto catFuncOp = getOrInsertFuncOpDecl(catOp, "_dfr_create_async_task", + catFuncType, rewriter); + rewriter.replaceOpWithNewOp(catOp, catFuncOp, + transformed.getOperands()); + return success(); + } +}; +struct DeallocateFutureOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::DeallocateFutureOp dfOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::DeallocateFutureOp::Adaptor transformed(operands); + auto dfFuncType = LLVM::LLVMFunctionType::get( + getVoidType(), {getVoidPtrI64Type(rewriter)}); + auto dfFuncOp = getOrInsertFuncOpDecl(dfOp, "_dfr_deallocate_future", + dfFuncType, rewriter); + rewriter.replaceOpWithNewOp(dfOp, dfFuncOp, + transformed.getOperands()); + return success(); + } +}; +struct DeallocateFutureDataOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + RT::DeallocateFutureDataOp>::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::DeallocateFutureDataOp dfdOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::DeallocateFutureDataOp::Adaptor transformed(operands); + auto dfdFuncType = LLVM::LLVMFunctionType::get( + getVoidType(), {getVoidPtrI64Type(rewriter)}); + auto dfdFuncOp = getOrInsertFuncOpDecl(dfdOp, "_dfr_deallocate_future_data", + dfdFuncType, rewriter); + rewriter.replaceOpWithNewOp(dfdOp, dfdFuncOp, + transformed.getOperands()); + return success(); + } +}; +struct BuildReturnPtrPlaceholderOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + RT::BuildReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::BuildReturnPtrPlaceholderOp befOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + OpBuilder::InsertionGuard guard(rewriter); + + // 
BuildReturnPtrPlaceholder is a placeholder for generating a memory + // location where a pointer to allocated memory can be written so + // that we can return outputs from task work function. + Value one = rewriter.create( + befOp.getLoc(), + (*getTypeConverter()).convertType(rewriter.getIndexType()), + rewriter.getIntegerAttr( + (*getTypeConverter()).convertType(rewriter.getIndexType()), 1)); + rewriter.replaceOpWithNewOp( + befOp, mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)), + one, + /*alignment=*/ + rewriter.getIntegerAttr( + (*getTypeConverter()).convertType(rewriter.getIndexType()), 0)); + return success(); + } +}; +struct DerefReturnPtrPlaceholderOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + RT::DerefReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::DerefReturnPtrPlaceholderOp drppOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::DerefReturnPtrPlaceholderOp::Adaptor transformed(operands); + + // DerefReturnPtrPlaceholder is a placeholder for generating a + // dereference operation for the pointer used to get results from + // task. 
+ rewriter.replaceOpWithNewOp( + drppOp, transformed.getOperands().front()); + return success(); + } +}; +struct DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering + : public ConvertOpToLLVMPattern< + RT::DerefWorkFunctionArgumentPtrPlaceholderOp> { + using ConvertOpToLLVMPattern< + RT::DerefWorkFunctionArgumentPtrPlaceholderOp>::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::DerefWorkFunctionArgumentPtrPlaceholderOp dwfappOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::DerefWorkFunctionArgumentPtrPlaceholderOp::Adaptor transformed( + operands); + OpBuilder::InsertionGuard guard(rewriter); + + // DerefWorkFunctionArgumentPtrPlaceholderOp is a placeholder for + // generating a dereference operation for the pointer used to pass + // arguments to the task. + rewriter.replaceOpWithNewOp( + dwfappOp, transformed.getOperands().front()); + return success(); + } +}; +struct WorkFunctionReturnOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + RT::WorkFunctionReturnOp>::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::WorkFunctionReturnOp wfrOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::WorkFunctionReturnOp::Adaptor transformed(operands); + rewriter.replaceOpWithNewOp( + wfrOp, transformed.getOperands().front(), + transformed.getOperands().back()); + return success(); + } +}; +} // end anonymous namespace +} // namespace zamalang +} // namespace mlir + +void mlir::zamalang::populateRTToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // clang-format off + patterns.add< + MakeReadyFutureOpInterfaceLowering, + AwaitFutureOpInterfaceLowering, + BuildReturnPtrPlaceholderOpInterfaceLowering, + DerefReturnPtrPlaceholderOpInterfaceLowering, + DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering, + CreateAsyncTaskOpInterfaceLowering, + DeallocateFutureOpInterfaceLowering, 
+ DeallocateFutureDataOpInterfaceLowering, + WorkFunctionReturnOpInterfaceLowering>(converter); + // clang-format on +} diff --git a/compiler/lib/Dialect/RT/CMakeLists.txt b/compiler/lib/Dialect/RT/CMakeLists.txt index f33061b2d..4f7494893 100644 --- a/compiler/lib/Dialect/RT/CMakeLists.txt +++ b/compiler/lib/Dialect/RT/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(Analysis) add_subdirectory(IR) diff --git a/compiler/lib/Dialect/RT/IR/RTDialect.cpp b/compiler/lib/Dialect/RT/IR/RTDialect.cpp index ef6d6ae19..f14e79afa 100644 --- a/compiler/lib/Dialect/RT/IR/RTDialect.cpp +++ b/compiler/lib/Dialect/RT/IR/RTDialect.cpp @@ -1,3 +1,16 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/FunctionImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" + #include "zamalang/Dialect/RT/IR/RTDialect.h" #include "zamalang/Dialect/RT/IR/RTOps.h" #include "zamalang/Dialect/RT/IR/RTTypes.h" @@ -24,7 +37,7 @@ void RTDialect::initialize() { ::mlir::Type RTDialect::parseType(::mlir::DialectAsmParser &parser) const { mlir::Type type; if (parser.parseOptionalKeyword("future").succeeded()) { - generatedTypeParser(this->getContext(), parser, "future", type); + generatedTypeParser(parser, "future", type); return type; } return type; @@ -35,4 +48,4 @@ void RTDialect::printType(::mlir::Type type, if (generatedTypePrinter(type, printer).failed()) { printer.printType(type); } -} \ No newline at end of file +} diff --git a/compiler/lib/Dialect/RT/IR/RTOps.cpp b/compiler/lib/Dialect/RT/IR/RTOps.cpp index e8b59fc5b..cec851945 100644 --- a/compiler/lib/Dialect/RT/IR/RTOps.cpp +++ b/compiler/lib/Dialect/RT/IR/RTOps.cpp @@ -1,3 +1,7 @@ +#include 
"mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeUtilities.h" @@ -6,3 +10,17 @@ #define GET_OP_CLASSES #include "zamalang/Dialect/RT/IR/RTOps.cpp.inc" + +using namespace mlir::zamalang::RT; + +void DataflowTaskOp::build( + ::mlir::OpBuilder &builder, ::mlir::OperationState &result, + ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, + ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + result.addOperands(operands); + result.addAttributes(attributes); + Region *reg = result.addRegion(); + Block *body = new Block(); + reg->push_back(body); + result.addTypes(resultTypes); +} diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 098b83b56..220689785 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -28,6 +28,7 @@ add_mlir_library(ZamalangSupport LowLFHEUnparametrize MLIRLowerableDialectsToLLVM HLFHEDialectAnalysis + RTDialectAnalysis MLIRExecutionEngine ${LLVM_PTHREAD_LIB} diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 4f0e65c40..322a826b6 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -84,6 +84,8 @@ void CompilerEngine::setVerifyDiagnostics(bool v) { this->verifyDiagnostics = v; } +void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; } + void CompilerEngine::setGenerateClientParameters(bool v) { this->generateClientParameters = v; } @@ -215,6 +217,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return errorDiag("Tiling of HLFHELinalg operations failed"); } + // Auto parallelization + if (this->autoParallelize && + mlir::zamalang::pipeline::autopar(mlirContext, module, enablePass) + .failed()) { + return StreamStringError("Auto parallelization failed"); + } + if (target == 
Target::HLFHE) return std::move(res); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 4a7f774df..3bca8909c 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -102,6 +103,17 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, return ret; } +mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("AutoPar", pm, context); + + addPotentiallyNestedPass( + pm, mlir::zamalang::createBuildDataflowTaskGraphPass(), enablePass); + + return pm.run(module.getOperation()); +} + mlir::LogicalResult tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { @@ -190,8 +202,18 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, enablePass); addPotentiallyNestedPass(pm, mlir::createSCFBufferizePass(), enablePass); addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass); + addPotentiallyNestedPass( + pm, mlir::zamalang::createBufferizeDataflowTaskOpsPass(), enablePass); addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(), enablePass); + + // Lower Dataflow tasks to DFR + addPotentiallyNestedPass(pm, mlir::zamalang::createFixupDataflowTaskOpsPass(), + enablePass); + addPotentiallyNestedPass(pm, mlir::zamalang::createLowerDataflowTasksPass(), + enablePass); + addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(), + enablePass); addPotentiallyNestedPass(pm, mlir::createLowerToCFGPass(), enablePass); // Convert to MLIR LLVM Dialect diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 2e3ff53ee..6e6dc6be5 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -123,6 +123,11 @@ llvm::cl::opt splitInputFile( "chunk independently"), llvm::cl::init(false)); 
+llvm::cl::opt autoParallelize( + "parallelize", + llvm::cl::desc("Generate (and execute if JIT) parallel code"), + llvm::cl::init(false)); + llvm::cl::opt jitFuncName( "jit-funcname", llvm::cl::desc("Name of the function to execute, default 'main'"), @@ -229,7 +234,7 @@ mlir::LogicalResult processInputBuffer( llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, llvm::Optional> hlfhelinalgTileSizes, - llvm::raw_ostream &os, + bool autoParallelize, llvm::raw_ostream &os, std::shared_ptr outputLib) { std::shared_ptr ccx = mlir::zamalang::CompilationContext::createShared(); @@ -237,6 +242,7 @@ mlir::LogicalResult processInputBuffer( mlir::zamalang::JitCompilerEngine ce{ccx}; ce.setVerifyDiagnostics(verifyDiagnostics); + ce.setAutoParallelize(autoParallelize); if (cmdline::passes.size() != 0) { ce.setEnablePass([](mlir::Pass *pass) { return std::any_of( @@ -404,7 +410,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { std::move(inputBuffer), fileName, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, hlfhelinalgTileSizes, os, outputLib); + cmdline::verifyDiagnostics, hlfhelinalgTileSizes, + cmdline::autoParallelize, os, outputLib); }; auto &os = output->os(); auto res = mlir::failure();