mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): add Dataflow/RT dialect and code generation for dataflow auto parallelization.
This commit is contained in:
@@ -19,7 +19,7 @@ class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
// Generates an encrypted zero constant
|
||||
def ZeroEintOp : HLFHE_Op<"zero"> {
|
||||
def ZeroEintOp : HLFHE_Op<"zero", [NoSideEffect]> {
|
||||
let arguments = (ins);
|
||||
let results = (outs EncryptedIntegerType:$out);
|
||||
}
|
||||
|
||||
28
compiler/include/zamalang/Dialect/RT/Analysis/Autopar.h
Normal file
28
compiler/include/zamalang/Dialect/RT/Analysis/Autopar.h
Normal file
@@ -0,0 +1,28 @@
|
||||
#ifndef ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR_H
#define ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR_H

#include <functional>
#include <memory> // std::unique_ptr is used by every factory declared below

#include <mlir/Pass/Pass.h>
#include <zamalang/Dialect/RT/IR/RTOps.h>

namespace mlir {

// Forward declarations: these types are only used by reference here.
class LLVMTypeConverter;
class BufferizeTypeConverter;
class RewritePatternSet;

namespace zamalang {

/// Creates the `BuildDataflowTaskGraph` pass (described in Autopar.td).
/// `debug` is stored by the pass object.
std::unique_ptr<mlir::Pass>
createBuildDataflowTaskGraphPass(bool debug = false);

/// Creates the `LowerDataflowTasks` pass (described in Autopar.td).
std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);

/// Creates the `BufferizeDataflowTaskOps` pass (described in Autopar.td).
/// `debug` is stored by the pass object.
std::unique_ptr<mlir::Pass>
createBufferizeDataflowTaskOpsPass(bool debug = false);

/// Creates the `FixupDataflowTaskOps` pass (described in Autopar.td).
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);

/// Adds the RT dialect patterns to `patterns` for use with `converter`.
/// NOTE(review): the definition is not part of this header; judging by the
/// name these are the RT -> LLVM dialect conversion patterns - confirm.
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
                                        mlir::RewritePatternSet &patterns);

/// Adds the bufferization patterns for `RT::DataflowTaskOp` and
/// `RT::DataflowYieldOp` to `patterns`, using `typeConverter` for the
/// tensor -> memref type mapping.
void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,
                                 mlir::RewritePatternSet &patterns);

} // namespace zamalang
} // namespace mlir

#endif
|
||||
87
compiler/include/zamalang/Dialect/RT/Analysis/Autopar.td
Normal file
87
compiler/include/zamalang/Dialect/RT/Analysis/Autopar.td
Normal file
@@ -0,0 +1,87 @@
|
||||
#ifndef ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR
|
||||
#define ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Identify profitable dataflow tasks and build DataflowTaskGraph.";
|
||||
|
||||
let description = [{
|
||||
This pass builds a dataflow graph out of a HLFHE program.
|
||||
|
||||
In its current incarnation, it considers some heavier weight
|
||||
operations (e.g., HLFHELinalg Dot and Matmul or bootstraps) as
|
||||
candidates for being executed in a discrete task, and then
|
||||
sinks into the task the lighter weight operations that do not
|
||||
increase the graph cut (amount of dependences in or out).
|
||||
|
||||
The output is a program partitioned into RT::DataflowTaskOp ops that
|
||||
expose task dependences as arguments and results of the
|
||||
DataflowTaskOp.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
%0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %0 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
```
|
||||
|
||||
Will result in generating a dataflow task for the Matmul operation:
|
||||
|
||||
```mlir
|
||||
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
%0 = "RT.dataflow_task"(%arg0, %arg1) ( {
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
"RT.dataflow_yield"(%1) : (tensor<3x2x!HLFHE.eint<2>>) -> ()
|
||||
}) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %0 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
def BufferizeDataflowTaskOps : Pass<"BufferizeDataflowTaskOps", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Bufferize DataflowTaskOp(s).";
|
||||
|
||||
let description = [{
|
||||
This pass lowers DataflowTaskOp arguments and results from tensors
|
||||
to mlir::memref. It also lowers the arguments of DataflowYieldOp.
|
||||
}];
|
||||
}
|
||||
|
||||
def FixupDataflowTaskOps : Pass<"FixupDataflowTaskOps", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Fix DataflowTaskOp(s) before lowering.";
|
||||
|
||||
let description = [{
|
||||
This pass fixes up code changes that intervene between the
|
||||
BuildDataflowTaskGraph pass and the lowering of the taskgraph to
|
||||
LLVMIR and calls to the DFR runtime system.
|
||||
|
||||
In particular, some operations (e.g., constants, dimension
|
||||
operations, etc.) can be used within the task while only defined
|
||||
outside. In most cases cloning and sinking these operations in the
|
||||
task is the simplest way to avoid adding dependences.
|
||||
|
||||
}];
|
||||
}
|
||||
|
||||
def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Outline the body of a DataflowTaskOp into a separate function which will serve as a task work function and lower the task graph to RT.";
|
||||
|
||||
let description = [{
|
||||
This pass lowers a DataflowTaskGraph to the RT dialect, outlining
|
||||
DataflowTaskOp into separate work functions and introducing the
|
||||
necessary operations to communicate and synchronize execution via
|
||||
futures.
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,6 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Autopar.td)
|
||||
mlir_tablegen(Autopar.h.inc -gen-pass-decls -name Analysis)
|
||||
mlir_tablegen(Autopar.capi.h.inc -gen-pass-capi-header --prefix Analysis)
|
||||
mlir_tablegen(Autopar.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis)
|
||||
add_public_tablegen_target(AutoparPassIncGen)
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
|
||||
include "zamalang/Dialect/RT/IR/RTDialect.td"
|
||||
include "zamalang/Dialect/RT/IR/RTTypes.td"
|
||||
@@ -16,9 +18,24 @@ def DataflowTaskOp : RT_Op<"dataflow_task", [SingleBlockImplicitTerminator<"Data
|
||||
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins
|
||||
CArg<"TypeRange", "{}">: $resultTypes,
|
||||
CArg<"ValueRange", "{}">: $operands,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">: $attrs)>
|
||||
];
|
||||
let skipDefaultBuilders = 1;
|
||||
|
||||
let summary = "Dataflow task operation";
|
||||
let description = [{
|
||||
`RT.dataflow_task` allows to specify a task that will be concurrently executed when their operands are ready.
|
||||
|
||||
`RT.dataflow_task` specifies a task that will be concurrently
|
||||
executed when its operands are ready. Operands are either the
|
||||
results of computation in other `RT.dataflow_task` (dataflow
|
||||
dependences) or obtained from the execution context (immediate
|
||||
operands). Operands are synchronized using futures and, in the case
|
||||
of immediate operands, copied when the task is created. Caution is
|
||||
required when the operand is a pointer as no deep copy will occur.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -43,6 +60,7 @@ func @test(%0 : i64): (i64, i64) {
|
||||
}) : (i64, i64) -> (i64, i64)
|
||||
return %3, %4 : (i64, i64)
|
||||
}
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -52,7 +70,7 @@ def DataflowYieldOp : RT_Op<"dataflow_yield", [ReturnLike, Terminator]> {
|
||||
let summary = "Dataflow yield operation";
|
||||
let description = [{
|
||||
`RT.dataflow_yield` is a special terminator operation for blocks inside the region
|
||||
in `RT.dataflow_task`. It allows to specify the returns values of a `RT.dataflow_task`.
|
||||
in `RT.dataflow_task`. It allows to specify the return values of a `RT.dataflow_task`.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -64,4 +82,64 @@ Example:
|
||||
}];
|
||||
}
|
||||
|
||||
def MakeReadyFutureOp : RT_Op<"make_ready_future"> {
|
||||
let arguments = (ins AnyType: $input);
|
||||
let results = (outs RT_Future: $output);
|
||||
let summary = "Build a ready future.";
|
||||
let description = [{
|
||||
Data passed to dataflow tasks must be encapsulated in futures,
|
||||
including immediate operands. These must be converted into futures
|
||||
using `RT.make_ready_future`.
|
||||
}];
|
||||
}
|
||||
|
||||
def AwaitFutureOp : RT_Op<"await_future"> {
|
||||
let arguments = (ins RT_Future: $input);
|
||||
let results = (outs AnyType: $output);
|
||||
let summary = "Wait for a future and access its data.";
|
||||
let description = [{
|
||||
The results of a dataflow task are always futures which could be
|
||||
further used as inputs to subsequent tasks. When the result of a task
|
||||
is needed in the outer execution context, the result future needs to
|
||||
be synchronized and its data accessed using `RT.await_future`.
|
||||
}];
|
||||
}
|
||||
|
||||
def CreateAsyncTaskOp : RT_Op<"create_async_task"> {
|
||||
let arguments = (ins SymbolRefAttr:$workfn,
|
||||
Variadic<AnyType>:$list);
|
||||
let results = (outs );
|
||||
let summary = "Create a dataflow task.";
|
||||
}
|
||||
|
||||
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
let arguments = (ins RT_Future: $input);
|
||||
let results = (outs );
|
||||
}
|
||||
|
||||
def DeallocateFutureDataOp : RT_Op<"deallocate_future_data"> {
|
||||
let arguments = (ins RT_Future: $input);
|
||||
let results = (outs );
|
||||
}
|
||||
|
||||
def BuildReturnPtrPlaceholderOp : RT_Op<"build_return_ptr_placeholder"> {
|
||||
let arguments = (ins );
|
||||
let results = (outs RT_Pointer: $output);
|
||||
}
|
||||
|
||||
def DerefReturnPtrPlaceholderOp : RT_Op<"deref_return_ptr_placeholder"> {
|
||||
let arguments = (ins RT_Pointer: $input);
|
||||
let results = (outs RT_Future: $output);
|
||||
}
|
||||
|
||||
def DerefWorkFunctionArgumentPtrPlaceholderOp : RT_Op<"deref_work_function_argument_ptr_placeholder"> {
|
||||
let arguments = (ins RT_Pointer: $input);
|
||||
let results = (outs AnyType: $output);
|
||||
}
|
||||
|
||||
def WorkFunctionReturnOp : RT_Op<"work_function_return"> {
|
||||
let arguments = (ins AnyType:$in, AnyType:$out);
|
||||
let results = (outs );
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -46,7 +46,40 @@ def RT_Future : RT_Type<"Future"> {
|
||||
return Type();
|
||||
return get($_ctxt, elementType);
|
||||
}];
|
||||
//let genVerifyDecl = 1;
|
||||
}
|
||||
|
||||
def RT_Pointer : RT_Type<"Pointer"> {
|
||||
let mnemonic = "rtptr";
|
||||
|
||||
let summary = "Pointer to a parameterized element type";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$elementType);
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
|
||||
return $_get(elementType.getContext(), elementType);
|
||||
}]>
|
||||
];
|
||||
|
||||
let printer = [{
|
||||
$_printer << "rtptr<";
|
||||
$_printer.printType(getElementType());
|
||||
$_printer << ">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
Type elementType;
|
||||
if ($_parser.parseType(elementType))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, elementType);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -122,7 +122,7 @@ public:
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(),
|
||||
clientParametersFuncName(), verifyDiagnostics(false),
|
||||
generateClientParameters(false),
|
||||
autoParallelize(false), generateClientParameters(false),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
@@ -146,6 +146,7 @@ public:
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setAutoParallelize(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
|
||||
@@ -158,6 +159,7 @@ protected:
|
||||
llvm::Optional<std::vector<int64_t>> hlfhelinalgTileSizes;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
bool autoParallelize;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace pipeline {
|
||||
|
||||
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTOps.h"
|
||||
|
||||
namespace {
|
||||
struct HLFHEToMidLFHEPass : public HLFHEToMidLFHEBase<HLFHEToMidLFHEPass> {
|
||||
@@ -92,6 +93,12 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
|
||||
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::zamalang::RT::DataflowTaskOp>(
|
||||
target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTOps.h"
|
||||
#include "zamalang/Support/Constants.h"
|
||||
|
||||
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
|
||||
@@ -123,6 +124,12 @@ void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
patterns.getContext(), converter);
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::CallOp>(target, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
|
||||
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::zamalang::RT::DataflowTaskOp>(
|
||||
target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/RT/Analysis/Autopar.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTTypes.h"
|
||||
|
||||
namespace {
|
||||
struct MLIRLowerableDialectsToLLVMPass
|
||||
@@ -52,6 +54,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
|
||||
// Setup the set of the patterns rewriter. At this point we want to
|
||||
// convert the `scf` operations to `std` and `std` operations to `llvm`.
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
mlir::zamalang::populateRTToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||
patterns);
|
||||
@@ -72,10 +75,28 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
|
||||
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>()) {
|
||||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>() ||
|
||||
type.isa<mlir::zamalang::RT::FutureType>()) {
|
||||
return mlir::LLVM::LLVMPointerType::get(
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
}
|
||||
if (type.isa<mlir::zamalang::RT::PointerType>()) {
|
||||
mlir::LowerToLLVMOptions options(type.getContext());
|
||||
mlir::LLVMTypeConverter typeConverter(type.getContext(), options);
|
||||
typeConverter.addConversion(convertTypes);
|
||||
typeConverter.addConversion(
|
||||
[&](mlir::zamalang::LowLFHE::PlaintextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
typeConverter.addConversion(
|
||||
[&](mlir::zamalang::LowLFHE::CleartextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
mlir::Type subtype =
|
||||
type.dyn_cast<mlir::zamalang::RT::PointerType>().getElementType();
|
||||
mlir::Type convertedSubtype = typeConverter.convertType(subtype);
|
||||
return mlir::LLVM::LLVMPointerType::get(convertedSubtype);
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTOps.h"
|
||||
#include "zamalang/Support/Constants.h"
|
||||
|
||||
namespace {
|
||||
@@ -300,6 +301,12 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
|
||||
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<
|
||||
mlir::zamalang::RT::DataflowTaskOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTOps.h"
|
||||
|
||||
namespace {
|
||||
struct MidLFHEToLowLFHEPass
|
||||
@@ -89,6 +90,12 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
|
||||
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
|
||||
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::zamalang::RT::DataflowTaskOp>(
|
||||
target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
|
||||
119
compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp
Normal file
119
compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp
Normal file
@@ -0,0 +1,119 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTOps.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTTypes.h>
|
||||
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/Transforms/Bufferize.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
#include <zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
namespace {
|
||||
class BufferizeDataflowYieldOp
|
||||
: public OpConversionPattern<RT::DataflowYieldOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(RT::DataflowYieldOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::DataflowYieldOp::Adaptor transformed(operands);
|
||||
rewriter.replaceOpWithNewOp<RT::DataflowYieldOp>(op, mlir::TypeRange(),
|
||||
transformed.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
/// Conversion pattern for `RT::DataflowTaskOp`: creates a new task op whose
/// result types come from the type converter and whose operands are the ones
/// already remapped by the conversion framework, then moves the body of the
/// original task into it.
class BufferizeDataflowTaskOp : public OpConversionPattern<RT::DataflowTaskOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Returns failure (without touching the IR) when a result type cannot be
  /// converted; otherwise replaces `op` and returns success.
  LogicalResult
  matchAndRewrite(RT::DataflowTaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    RT::DataflowTaskOp::Adaptor transformed(operands);
    mlir::OpBuilder::InsertionGuard guard(rewriter);

    // Do not silently continue with a partial list of result types: bail out
    // before any rewrite has been recorded if the conversion fails.
    SmallVector<Type> newResults;
    if (failed(
            getTypeConverter()->convertTypes(op.getResultTypes(), newResults)))
      return failure();

    auto newop = rewriter.create<RT::DataflowTaskOp>(op.getLoc(), newResults,
                                                     transformed.getOperands());
    // We cannot clone here as cloned ops must be legalized (so this
    // would break on the YieldOp). Instead use mergeBlocks which
    // moves the ops instead of cloning.
    rewriter.mergeBlocks(op.getBody(), newop.getBody(),
                         newop.getBody()->getArguments());
    // Because of previous bufferization there are buffer cast ops
    // that have been generated for the previously tensor results of
    // some tasks. These cannot just be replaced directly as the
    // task's results would still be live.
    for (auto res : llvm::enumerate(op.getResults())) {
      // If this result is getting bufferized ...
      if (res.value().getType() !=
          getTypeConverter()->convertType(res.value().getType())) {
        // Early-increment iteration: replacing the owner of a use removes
        // that use from the list being walked.
        for (auto &use : llvm::make_early_inc_range(res.value().getUses())) {
          // ... and its uses are in `BufferCastOp`s, then we
          // replace further uses of the buffer cast.
          if (isa<mlir::memref::BufferCastOp>(use.getOwner())) {
            rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())});
          }
        }
      }
    }
    rewriter.replaceOp(op, {newop.getResults()});
    return success();
  }
};
|
||||
} // namespace
|
||||
|
||||
/// Registers the RT bufferization patterns (yield first, then task) into
/// `patterns`, all sharing `typeConverter`.
void populateRTBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                 RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BufferizeDataflowYieldOp>(typeConverter, ctx);
  patterns.add<BufferizeDataflowTaskOp>(typeConverter, ctx);
}
|
||||
|
||||
namespace {
|
||||
// For documentation see Autopar.td
|
||||
struct BufferizeDataflowTaskOpsPass
|
||||
: public BufferizeDataflowTaskOpsBase<BufferizeDataflowTaskOpsPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
BufferizeTypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
|
||||
populateRTBufferizePatterns(typeConverter, patterns);
|
||||
|
||||
// Forbid all RT ops that still use/return tensors
|
||||
target.addDynamicallyLegalDialect<RT::RTDialect>(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
|
||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
BufferizeDataflowTaskOpsPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createBufferizeDataflowTaskOpsPass(bool debug) {
|
||||
return std::make_unique<BufferizeDataflowTaskOpsPass>(debug);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
266
compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp
Normal file
266
compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp
Normal file
@@ -0,0 +1,266 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
|
||||
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h>
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTOps.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTTypes.h>
|
||||
#include <zamalang/Support/Constants.h>
|
||||
#include <zamalang/Support/math.h>
|
||||
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
#include <mlir/Transforms/Utils.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO: adjust these two functions based on cost model
|
||||
/// Returns true when `op` is one of the HLFHE / HLFHELinalg operations that
/// get wrapped into their own dataflow task.
static bool isCandidateForTask(Operation *op) {
  // Operations from the HLFHE dialect.
  if (isa<HLFHE::AddEintIntOp, HLFHE::AddEintOp, HLFHE::SubIntEintOp,
          HLFHE::MulEintIntOp, HLFHE::ApplyLookupTableEintOp>(op))
    return true;

  // Operations from the HLFHELinalg dialect.
  return isa<HLFHELinalg::MatMulIntEintOp, HLFHELinalg::MatMulEintIntOp,
             HLFHELinalg::AddEintIntOp, HLFHELinalg::AddEintOp,
             HLFHELinalg::SubIntEintOp, HLFHELinalg::NegEintOp,
             HLFHELinalg::MulEintIntOp, HLFHELinalg::ApplyLookupTableEintOp,
             HLFHELinalg::ApplyMultiLookupTableEintOp, HLFHELinalg::Dot>(op);
}
|
||||
|
||||
// Identify operations that are beneficial to sink into tasks. These
|
||||
// operations must not have side-effects and not be `isCandidateForTask`
|
||||
/// Returns true when `op` belongs to the fixed list of operations that may be
/// cloned into a task body.
static bool isSinkingBeneficiary(Operation *op) {
  const bool sinkable =
      isa<HLFHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, SelectOp,
          mlir::arith::CmpIOp>(op);
  return sinkable;
}
|
||||
|
||||
static bool
|
||||
extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
|
||||
SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues) {
|
||||
if (beneficiaryOps.count(op))
|
||||
return true;
|
||||
|
||||
if (!isSinkingBeneficiary(op))
|
||||
return false;
|
||||
|
||||
for (Value operand : op->getOperands()) {
|
||||
// It is already visible in the kernel, keep going.
|
||||
if (availableValues.count(operand))
|
||||
continue;
|
||||
// Else check whether it can be made available via sinking or already is a
|
||||
// dependency.
|
||||
Operation *definingOp = operand.getDefiningOp();
|
||||
if ((!definingOp ||
|
||||
!extractBeneficiaryOps(definingOp, existingDependencies,
|
||||
beneficiaryOps, availableValues)) &&
|
||||
!existingDependencies.count(operand))
|
||||
return false;
|
||||
}
|
||||
// We will sink the operation, mark its results as now available.
|
||||
beneficiaryOps.insert(op);
|
||||
for (Value result : op->getResults())
|
||||
availableValues.insert(result);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Clones into the body of `taskOp` the sinkable operations that define
/// values used inside the task but defined above it, redirects the uses inside
/// the task to the clones, and drops the task operands that carried those
/// values. Always returns success.
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
  Region &taskRegion = taskOp.body();

  // Values used inside the task region but defined outside of it.
  SetVector<Value> externalValues;
  getUsedValuesDefinedAbove(taskRegion, externalValues);

  // Select the defining operations that can be sunk (block arguments have no
  // defining operation and are skipped).
  SetVector<Operation *> opsToSink;
  llvm::SmallPtrSet<Value, 4> availableValues;
  for (Value external : externalValues)
    if (Operation *producer = external.getDefiningOp())
      extractBeneficiaryOps(producer, externalValues, opsToSink,
                            availableValues);

  // Clone in selection order so that definitions are cloned before uses.
  BlockAndValueMapping mapping;
  OpBuilder builder(taskRegion);
  for (Operation *original : opsToSink) {
    OpBuilder::InsertionGuard guard(builder);
    Operation *clone = builder.clone(*original, mapping);

    for (unsigned i = 0, e = original->getNumResults(); i != e; ++i)
      replaceAllUsesInRegionWith(original->getResult(i), clone->getResult(i),
                                 taskRegion);

    // The sunk values no longer need to be passed to the task: remove the
    // first matching operand for each result (no duplicates are assumed).
    for (Value result : original->getResults()) {
      for (unsigned idx = 0, numOperands = taskOp->getNumOperands();
           idx != numOperands; ++idx) {
        if (taskOp->getOperand(idx) != result)
          continue;
        taskOp->eraseOperand(idx);
        break;
      }
    }
  }
  return success();
}
|
||||
|
||||
// For documentation see Autopar.td
|
||||
struct BuildDataflowTaskGraphPass
|
||||
: public BuildDataflowTaskGraphBase<BuildDataflowTaskGraphPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
module.walk([&](mlir::FuncOp func) {
|
||||
if (!func->getAttr("_dfr_work_function_attribute"))
|
||||
func.walk(
|
||||
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
|
||||
|
||||
// Perform simplifications, in particular DCE here in case some
|
||||
// of the operations sunk in tasks are no longer needed in the
|
||||
// main function. If the function fails it only means that
|
||||
// nothing was simplified. Doing this here - rather than later
|
||||
// in the compilation pipeline - allows to take advantage of
|
||||
// higher level semantics which we can attach to operations
|
||||
// (e.g., NoSideEffect on HLFHE::ZeroEintOp).
|
||||
IRRewriter rewriter(func->getContext());
|
||||
(void)mlir::simplifyRegions(rewriter, func->getRegions());
|
||||
});
|
||||
}
|
||||
BuildDataflowTaskGraphPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
/// Wraps `op` into a new RT::DataflowTaskOp when isCandidateForTask(op)
/// accepts it; does nothing otherwise.
///
/// The task is created right after `op` with the same operands and result
/// types. `op` is cloned into the task body, sinkOperationsIntoDFTask() is
/// run on the task, and a DataflowYieldOp terminates the body. Uses of
/// `op`'s results are then redirected before `op` itself is erased.
void processOperation(mlir::Operation *op) {
  if (!isCandidateForTask(op))
    return;

  BlockAndValueMapping map;
  Region &opBody = getOperation().body();
  OpBuilder builder(opBody);

  // Create a DFTask for this operation
  builder.setInsertionPointAfter(op);
  auto dftop = builder.create<RT::DataflowTaskOp>(
      op->getLoc(), op->getResultTypes(), op->getOperands());

  // Add the operation to the task
  OpBuilder tbbuilder(dftop.body());
  Operation *clonedOp = tbbuilder.clone(*op, map);

  // Add sinkable operations to the task. The call is deliberately kept
  // outside of assert(): with NDEBUG the asserted expression is not
  // evaluated at all, which would silently skip the sinking.
  LogicalResult sinkResult = sinkOperationsIntoDFTask(dftop);
  assert(!failed(sinkResult) && "Failing to sink operations into DFT");
  (void)sinkResult;

  // Add terminator
  tbbuilder.create<RT::DataflowYieldOp>(dftop.getLoc(), mlir::TypeRange(),
                                        op->getResults());

  // Inside the task body, refer to the clone's results instead of the
  // original operation's (this also re-targets the yield created above).
  for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults()))
    replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
                               dftop.body());

  // Everywhere else in the enclosing body, use the values defined by the
  // task.
  for (auto pair : llvm::zip(op->getResults(), dftop->getResults()))
    replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), opBody);

  // Once uses are re-targeted to the task, delete the operation
  op->erase();
}
|
||||
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
|
||||
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Marker to avoid infinite recursion of the rewriting pattern.
//
// Name of the unit attribute that FixDataflowTaskOpInputsPattern sets on the
// task ops it creates; the pattern refuses to match any op that carries it.
// FixupDataflowTaskOpsPass removes the attribute again once the greedy
// rewrite has finished.
static const mlir::StringLiteral kTransformMarker =
    "_internal_RT_FixDataflowTaskOpInputsPattern_marker__";
|
||||
|
||||
class FixDataflowTaskOpInputsPattern
|
||||
: public mlir::OpRewritePattern<RT::DataflowTaskOp> {
|
||||
public:
|
||||
FixDataflowTaskOpInputsPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<RT::DataflowTaskOp>(
|
||||
context, ::mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(RT::DataflowTaskOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
if (op->hasAttr(kTransformMarker))
|
||||
return failure();
|
||||
|
||||
// Identify which values need to be passed as dependences to the
|
||||
// task - this is very conservative and will add constants, index
|
||||
// operations, etc. A simplification will occur later.
|
||||
SetVector<Value> deps;
|
||||
getUsedValuesDefinedAbove(op.body(), deps);
|
||||
auto newop = rewriter.create<RT::DataflowTaskOp>(
|
||||
op.getLoc(), op.getResultTypes(), deps.getArrayRef());
|
||||
rewriter.mergeBlocks(op.getBody(), newop.getBody(),
|
||||
newop.getBody()->getArguments());
|
||||
rewriter.replaceOp(op, {newop.getResults()});
|
||||
|
||||
// Mark this as processed to prevent infinite loop
|
||||
newop.getOperation()->setAttr(kTransformMarker, rewriter.getUnitAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// For documentation see Autopar.td
|
||||
struct FixupDataflowTaskOpsPass
|
||||
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<FixDataflowTaskOpInputsPattern>(context);
|
||||
|
||||
if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))
|
||||
.failed())
|
||||
signalPassFailure();
|
||||
|
||||
// Clear mark and sink any newly created constants or indexing
|
||||
// operations, etc. to reduce the number of input dependences to
|
||||
// the task
|
||||
module->walk([](RT::DataflowTaskOp op) {
|
||||
op.getOperation()->removeAttr(kTransformMarker);
|
||||
assert(!failed(sinkOperationsIntoDFTask(op)) &&
|
||||
"Failing to sink operations into DFT");
|
||||
});
|
||||
}
|
||||
|
||||
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug) {
|
||||
return std::make_unique<FixupDataflowTaskOpsPass>(debug);
|
||||
}
|
||||
|
||||
} // end namespace zamalang
|
||||
} // end namespace mlir
|
||||
18
compiler/lib/Dialect/RT/Analysis/CMakeLists.txt
Normal file
18
compiler/lib/Dialect/RT/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
# Library built from the RT dialect pass and conversion-pattern sources
# listed below.
add_mlir_library(RTDialectAnalysis
  BufferizeDataflowTaskOps.cpp
  BuildDataflowTaskGraph.cpp
  LowerDataflowTasksToRT.cpp
  LowerRTToLLVMDFRCallsConversionPatterns.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/RT

  DEPENDS
  RTDialect
  AutoparPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  RTDialect)

# NOTE(review): MLIRIR is already listed under LINK_LIBS PUBLIC above, so this
# call looks redundant - confirm before removing.
target_link_libraries(RTDialectAnalysis PUBLIC MLIRIR)
|
||||
337
compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp
Normal file
337
compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp
Normal file
@@ -0,0 +1,337 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h>
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h>
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTOps.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTTypes.h>
|
||||
#include <zamalang/Support/math.h>
|
||||
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
#include <mlir/Conversion/LLVMCommon/Pattern.h>
|
||||
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
|
||||
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/Dialect/StandardOps/Transforms/FuncConversions.h>
|
||||
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/SymbolTable.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Transforms/Bufferize.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
#include <mlir/Transforms/Utils.h>
|
||||
#include <zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Outlines the body of `DFTOp` into a new FuncOp named `workFunctionName`
/// and returns it. The function is created detached (the builder has no
/// insertion point) and is tagged with `_dfr_work_function_attribute`.
///
/// The function returns nothing: it takes one RT pointer argument per task
/// operand followed by one per task result. Operands are read through
/// DerefWorkFunctionArgumentPtrPlaceholderOp, results are written through
/// WorkFunctionReturnOp.
static FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
                                  StringRef workFunctionName) {
  Location loc = DFTOp.getLoc();
  OpBuilder builder(DFTOp.getContext());
  Region &taskBody = DFTOp.body();
  OpBuilder::InsertionGuard guard(builder);

  // Instead of outlining with the same operands/results, we pass all
  // results as operands as well. For now we preserve the results'
  // types, which will be changed to use an indirection when lowering.
  SmallVector<Type, 4> argTypes;
  argTypes.reserve(DFTOp.getNumOperands() + DFTOp.getNumResults());
  auto appendPointerTo = [&argTypes](Value v) {
    argTypes.push_back(RT::PointerType::get(v.getType()));
  };
  for (Value input : DFTOp.getOperands())
    appendPointerTo(input);
  for (Value output : DFTOp.getResults())
    appendPointerTo(output);

  FunctionType funcType = FunctionType::get(DFTOp.getContext(), argTypes, {});
  auto outlinedFunc = builder.create<FuncOp>(loc, workFunctionName, funcType);
  outlinedFunc->setAttr("_dfr_work_function_attribute", builder.getUnitAttr());

  // Entry block carrying the pointer arguments.
  Region &funcBody = outlinedFunc.body();
  Block *entry = new Block;
  entry->addArguments(funcType.getInputs());
  funcBody.push_back(entry);

  // Dereference each input argument and remap the task operand to the
  // dereferenced value for the cloned body.
  BlockAndValueMapping map;
  builder.setInsertionPointToStart(entry);
  unsigned argIndex = 0;
  for (Value input : DFTOp.getOperands()) {
    auto deref =
        builder.create<RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
            loc, input.getType(), entry->getArgument(argIndex));
    map.map(input, deref->getResult(0));
    ++argIndex;
  }
  taskBody.cloneInto(&funcBody, map);

  // Jump from the entry block into the clone of the task's first block.
  Block *clonedTaskEntry = map.lookup(&taskBody.front());
  builder.setInsertionPointToEnd(entry);
  builder.create<BranchOp>(loc, clonedTaskEntry);

  // TODO: we use a WorkFunctionReturnOp to tie return to the
  // corresponding argument. This can be lowered to a copy/deref for
  // shared memory and pointers, but needs to be handled for
  // distributed memory.
  const int firstOutputArg = DFTOp.getNumOperands();
  outlinedFunc.walk([&](RT::DataflowYieldOp yield) {
    OpBuilder replacer(yield);
    for (auto yielded : llvm::enumerate(yield.getOperands()))
      replacer.create<RT::WorkFunctionReturnOp>(
          yield.getLoc(), yielded.value(),
          outlinedFunc.getArgument(yielded.index() + firstOutputArg));
    replacer.create<ReturnOp>(yield.getLoc());
    yield.erase();
  });
  return outlinedFunc;
}
|
||||
|
||||
static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement,
|
||||
Region ®ion) {
|
||||
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
|
||||
if (isa<RT::DataflowTaskOp>(use.getOwner()) &&
|
||||
region.isAncestor(use.getOwner()->getParentRegion()))
|
||||
use.set(replacement);
|
||||
}
|
||||
}
|
||||
static void replaceAllUsesNotInDFTsInRegionWith(Value orig, Value replacement,
|
||||
Region ®ion) {
|
||||
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
|
||||
if (!isa<RT::DataflowTaskOp>(use.getOwner()) &&
|
||||
use.getOwner()->getParentOfType<RT::DataflowTaskOp>() == nullptr &&
|
||||
region.isAncestor(use.getOwner()->getParentRegion()))
|
||||
use.set(replacement);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Fix type sizes. For now we're using some default values.
|
||||
static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
DataLayout dataLayout = DataLayout::closest(val.getDefiningOp());
|
||||
Type type = (val.getType().isa<RT::FutureType>())
|
||||
? val.getType().dyn_cast<RT::FutureType>().getElementType()
|
||||
: val.getType();
|
||||
|
||||
// In the case of memref, we need to determine how much space
|
||||
// (conservatively) we need to store the memref itself. Overshooting
|
||||
// by a few bytes should not be an issue, so the main thing is to
|
||||
// properly account for the rank.
|
||||
if (type.isa<mlir::MemRefType>()) {
|
||||
// Space for the allocated and aligned pointers, and offset
|
||||
Value ptrs_offset =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(24));
|
||||
// For the sizes and shapes arrays, we need 2*8 = 16 times the rank in bytes
|
||||
Value multiplier =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
|
||||
unsigned _rank = type.dyn_cast<mlir::MemRefType>().getRank();
|
||||
Value rank = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(_rank));
|
||||
Value sizes_shapes = builder.create<LLVM::MulOp>(loc, rank, multiplier);
|
||||
Value result = builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Unranked memrefs should be lowered to just pointer + size, so we need 16
|
||||
// bytes.
|
||||
if (type.isa<mlir::UnrankedMemRefType>())
|
||||
return builder.create<arith::ConstantOp>(loc,
|
||||
builder.getI64IntegerAttr(16));
|
||||
|
||||
// FHE types are converted to pointers, so we take their size as 8
|
||||
// bytes until we can get the actual size of the actual types.
|
||||
if (type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::GlweCiphertextType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::LweKeySwitchKeyType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>())
|
||||
return builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
|
||||
// For all other types, get type size.
|
||||
return builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type)));
|
||||
}
|
||||
|
||||
/// Replaces `DFTOp` by RT runtime ops that launch `workFunction`
/// asynchronously, then erases `DFTOp`.
///
/// All generated ops are inserted immediately before `DFTOp`. Non-future
/// operands are wrapped in MakeReadyFutureOp, a CreateAsyncTaskOp is emitted
/// with (work function, #inputs, #outputs, {value, size}...) operands, and
/// each result is exposed as a future that is awaited when used outside of
/// dataflow tasks.
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, FuncOp workFunction) {
  Region &opBody = DFTOp->getParentOfType<FuncOp>().body();
  Location loc = DFTOp.getLoc();
  // Maps each task result to the placeholder pointer that receives it.
  BlockAndValueMapping map;
  // Constructed on the op: the insertion point is right before `DFTOp`.
  OpBuilder builder(DFTOp);

  // First identify DFT operands that are not futures and are not
  // defined by another DFT. These need to be made into futures and
  // propagated to all other DFTs. We can allow PRE to eliminate the
  // previous definitions if there are no non-future type uses.
  for (Value val : DFTOp.getOperands()) {
    if (val.getType().isa<RT::FutureType>())
      continue;
    Type futType = RT::FutureType::get(val.getType());
    auto mrf = builder.create<RT::MakeReadyFutureOp>(loc, futType, val);
    replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody);
  }

  // Second generate a CreateAsyncTaskOp that will replace the
  // DataflowTaskOp. This also includes the necessary handling of
  // operands and results (conversion to/from futures and propagation).
  SmallVector<Value, 4> catOperands;
  catOperands.reserve(3 + DFTOp.getNumResults() * 2 +
                      DFTOp.getNumOperands() * 2);
  auto fnptr = builder.create<mlir::ConstantOp>(
      loc, workFunction.getType(),
      SymbolRefAttr::get(builder.getContext(), workFunction.getName()));
  auto numIns = builder.create<arith::ConstantOp>(
      loc, builder.getI64IntegerAttr(DFTOp.getNumOperands()));
  auto numOuts = builder.create<arith::ConstantOp>(
      loc, builder.getI64IntegerAttr(DFTOp.getNumResults()));
  catOperands.push_back(fnptr.getResult());
  catOperands.push_back(numIns.getResult());
  catOperands.push_back(numOuts.getResult());
  for (auto operand : DFTOp.getOperands()) {
    catOperands.push_back(operand);
    catOperands.push_back(getSizeInBytes(operand, loc, builder));
  }

  // We need to adjust the results for the CreateAsyncTaskOp which
  // are the work function's returns through pointers passed as
  // parameters. As this is not supported within MLIR - and mostly
  // unsupported even in the LLVMIR Dialect - this needs to use two
  // placeholders for each output, before and after the
  // CreateAsyncTaskOp.
  for (auto result : DFTOp.getResults()) {
    Type futPtrType =
        RT::PointerType::get(RT::FutureType::get(result.getType()));
    auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(loc, futPtrType);
    map.map(result, brpp->getResult(0));
    catOperands.push_back(brpp->getResult(0));
    catOperands.push_back(getSizeInBytes(result, loc, builder));
  }
  builder.create<RT::CreateAsyncTaskOp>(
      loc, SymbolRefAttr::get(builder.getContext(), workFunction.getName()),
      catOperands);

  // Third identify results of this DFT that are not used *only* in
  // other DFTs as those will need to be waited on explicitly.
  // We also create the DerefReturnPtrPlaceholderOp after the
  // CreateAsyncTaskOp. These also need propagating.
  for (auto result : DFTOp.getResults()) {
    Type futType = RT::FutureType::get(result.getType());
    Value futptr = map.lookupOrNull(result);
    assert(futptr);
    auto drpp =
        builder.create<RT::DerefReturnPtrPlaceholderOp>(loc, futType, futptr);
    replaceAllUsesInDFTsInRegionWith(result, drpp->getResult(0), opBody);

    // Wait for this future if at least one use lives outside of any DFT.
    // TODO: the wait function should ideally be issued as late as
    // possible, but need to identify which use comes first.
    bool usedOutsideTasks =
        llvm::any_of(result.getUses(), [](OpOperand &use) {
          Operation *user = use.getOwner();
          return !isa<RT::DataflowTaskOp>(user) &&
                 user->getParentOfType<RT::DataflowTaskOp>() == nullptr;
        });
    if (usedOutsideTasks) {
      // A single await is enough: the propagation hits all such uses.
      auto af = builder.create<RT::AwaitFutureOp>(loc, result.getType(),
                                                  drpp.getResult());
      replaceAllUsesNotInDFTsInRegionWith(result, af->getResult(0), opBody);
    }

    // All leftover uses (i.e. those within DFTs should use the future)
    // NOTE(review): this passes the placeholder pointer `futptr`, not the
    // dereferenced future `drpp` - confirm which one is intended.
    replaceAllUsesInRegionWith(result, futptr, opBody);
  }

  // Finally erase the DFT.
  DFTOp.erase();
}
|
||||
|
||||
// For documentation see Autopar.td
|
||||
struct LowerDataflowTasksPass
|
||||
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
module.walk([&](mlir::FuncOp func) {
|
||||
int wfn_id = 0;
|
||||
|
||||
// TODO: For now do not attempt to use nested parallelism.
|
||||
if (func->getAttr("_dfr_work_function_attribute"))
|
||||
return;
|
||||
|
||||
SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func);
|
||||
std::vector<std::pair<RT::DataflowTaskOp, FuncOp>> outliningMap;
|
||||
|
||||
func.walk([&](RT::DataflowTaskOp op) {
|
||||
auto workFunctionName = Twine("_dfr_DFT_work_function__") +
|
||||
Twine(op->getParentOfType<FuncOp>().getName()) +
|
||||
Twine(wfn_id++);
|
||||
FuncOp outlinedFunc = outlineWorkFunction(op, workFunctionName.str());
|
||||
outliningMap.push_back(
|
||||
std::pair<RT::DataflowTaskOp, FuncOp>(op, outlinedFunc));
|
||||
symbolTable.insert(outlinedFunc);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
// Lower the DF task ops to RT dialect ops.
|
||||
for (auto mapping : outliningMap)
|
||||
lowerDataflowTaskOp(mapping.first, mapping.second);
|
||||
|
||||
// Issue _dfr_start/stop calls for this function
|
||||
if (!outliningMap.empty()) {
|
||||
OpBuilder builder(func.body());
|
||||
builder.setInsertionPointToStart(&func.body().front());
|
||||
auto dfrStartFunOp = mlir::LLVM::lookupOrCreateFn(
|
||||
func->getParentOfType<ModuleOp>(), "_dfr_start", {},
|
||||
LLVM::LLVMVoidType::get(func->getContext()));
|
||||
builder.create<LLVM::CallOp>(func.getLoc(), dfrStartFunOp,
|
||||
mlir::ValueRange(),
|
||||
ArrayRef<NamedAttribute>());
|
||||
|
||||
builder.setInsertionPoint(func.body().back().getTerminator());
|
||||
auto dfrStopFunOp = mlir::LLVM::lookupOrCreateFn(
|
||||
func->getParentOfType<ModuleOp>(), "_dfr_stop", {},
|
||||
LLVM::LLVMVoidType::get(func->getContext()));
|
||||
builder.create<LLVM::CallOp>(func.getLoc(), dfrStopFunOp,
|
||||
mlir::ValueRange(),
|
||||
ArrayRef<NamedAttribute>());
|
||||
}
|
||||
});
|
||||
}
|
||||
LowerDataflowTasksPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug) {
|
||||
return std::make_unique<LowerDataflowTasksPass>(debug);
|
||||
}
|
||||
|
||||
} // end namespace zamalang
|
||||
} // end namespace mlir
|
||||
@@ -0,0 +1,310 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h>
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h>
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTOps.h>
|
||||
#include <zamalang/Dialect/RT/IR/RTTypes.h>
|
||||
#include <zamalang/Support/math.h>
|
||||
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <llvm/Support/Compiler.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
#include <mlir/Conversion/LLVMCommon/Pattern.h>
|
||||
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
|
||||
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/Dialect/StandardOps/Transforms/FuncConversions.h>
|
||||
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/SymbolTable.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Transforms/Bufferize.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
#include <mlir/Transforms/Utils.h>
|
||||
#include <zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Returns the LLVM dialect pointer type whose pointee is the 64-bit
/// integer type, built in the rewriter's context.
mlir::Type getVoidPtrI64Type(ConversionPatternRewriter &rewriter) {
  mlir::Type i64Type = mlir::IntegerType::get(rewriter.getContext(), 64);
  return mlir::LLVM::LLVMPointerType::get(i64Type);
}
|
||||
|
||||
/// Returns the LLVM function named `funcName` from the module enclosing
/// `op`, declaring it as a private function of type `funcType` at the start
/// of the module when it does not exist yet.
///
/// If a function with that name exists but is not private, an error is
/// emitted on `op` and a null op is returned.
LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op,
                                       llvm::StringRef funcName,
                                       LLVM::LLVMFunctionType funcType,
                                       ConversionPatternRewriter &rewriter) {
  auto module = op->getParentOfType<ModuleOp>();

  // Already in the symbol table: reuse it if it is one of our private
  // declarations, otherwise report the clash.
  if (auto existing = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
    if (existing.isPrivate())
      return existing;
    op->emitError()
        << "the function \"" << funcName
        << "\" conflicts with the Dataflow Runtime API, please rename.";
    return nullptr;
  }

  // Not declared yet: insert a private declaration at the top of the module.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  auto declared =
      rewriter.create<LLVM::LLVMFuncOp>(op->getLoc(), funcName, funcType);
  declared.setPrivate();
  return declared;
}
|
||||
|
||||
// This function is only needed for debug purposes to inspect values
|
||||
// in the generated code - it is therefore not generally in use.
|
||||
LLVM_ATTRIBUTE_UNUSED void
|
||||
insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op,
|
||||
Value val) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
auto printFnType = LLVM::LLVMFunctionType::get(
|
||||
LLVM::LLVMVoidType::get(rewriter.getContext()), {}, /*isVariadic=*/true);
|
||||
auto printFnOp =
|
||||
getOrInsertFuncOpDecl(op, "_dfr_print_debug", printFnType, rewriter);
|
||||
rewriter.create<LLVM::CallOp>(op->getLoc(), printFnOp, val);
|
||||
}
|
||||
|
||||
struct MakeReadyFutureOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::MakeReadyFutureOp> {
|
||||
using ConvertOpToLLVMPattern<RT::MakeReadyFutureOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::MakeReadyFutureOp mrfOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
RT::MakeReadyFutureOp::Adaptor transformed(operands);
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
// Normally this function takes a pointer as parameter
|
||||
auto mrfFuncType = LLVM::LLVMFunctionType::get(getVoidPtrI64Type(rewriter),
|
||||
{}, /*isVariadic=*/true);
|
||||
auto mrfFuncOp = getOrInsertFuncOpDecl(mrfOp, "_dfr_make_ready_future",
|
||||
mrfFuncType, rewriter);
|
||||
|
||||
// In order to support non pointer types, we need to allocate
|
||||
// explicitly space that we can reference as a base for the
|
||||
// future.
|
||||
auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn(
|
||||
mrfOp->getParentOfType<ModuleOp>(), getIndexType());
|
||||
auto sizeBytes = getSizeInBytes(
|
||||
mrfOp.getLoc(), transformed.getOperands().getTypes().front(), rewriter);
|
||||
auto results = mlir::LLVM::createLLVMCall(
|
||||
rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType());
|
||||
Value allocatedPtr = rewriter.create<mlir::LLVM::BitcastOp>(
|
||||
mrfOp.getLoc(),
|
||||
mlir::LLVM::LLVMPointerType::get(
|
||||
transformed.getOperands().getTypes().front()),
|
||||
results[0]);
|
||||
rewriter.create<LLVM::StoreOp>(
|
||||
mrfOp.getLoc(), transformed.getOperands().front(), allocatedPtr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, allocatedPtr);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
struct AwaitFutureOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::AwaitFutureOp> {
|
||||
using ConvertOpToLLVMPattern<RT::AwaitFutureOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::AwaitFutureOp afOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::AwaitFutureOp::Adaptor transformed(operands);
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
auto afFuncType = LLVM::LLVMFunctionType::get(
|
||||
mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)),
|
||||
{getVoidPtrI64Type(rewriter)});
|
||||
auto afFuncOp =
|
||||
getOrInsertFuncOpDecl(afOp, "_dfr_await_future", afFuncType, rewriter);
|
||||
auto afCallOp = rewriter.create<LLVM::CallOp>(afOp.getLoc(), afFuncOp,
|
||||
transformed.getOperands());
|
||||
Value futVal = rewriter.create<mlir::LLVM::BitcastOp>(
|
||||
afOp.getLoc(),
|
||||
mlir::LLVM::LLVMPointerType::get(
|
||||
(*getTypeConverter()).convertType(afOp.getResult().getType())),
|
||||
afCallOp.getResult(0));
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(afOp, futVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct CreateAsyncTaskOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::CreateAsyncTaskOp> {
|
||||
using ConvertOpToLLVMPattern<RT::CreateAsyncTaskOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::CreateAsyncTaskOp catOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::CreateAsyncTaskOp::Adaptor transformed(operands);
|
||||
auto catFuncType =
|
||||
LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
|
||||
auto catFuncOp = getOrInsertFuncOpDecl(catOp, "_dfr_create_async_task",
|
||||
catFuncType, rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(catOp, catFuncOp,
|
||||
transformed.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct DeallocateFutureOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::DeallocateFutureOp> {
|
||||
using ConvertOpToLLVMPattern<RT::DeallocateFutureOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::DeallocateFutureOp dfOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::DeallocateFutureOp::Adaptor transformed(operands);
|
||||
auto dfFuncType = LLVM::LLVMFunctionType::get(
|
||||
getVoidType(), {getVoidPtrI64Type(rewriter)});
|
||||
auto dfFuncOp = getOrInsertFuncOpDecl(dfOp, "_dfr_deallocate_future",
|
||||
dfFuncType, rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(dfOp, dfFuncOp,
|
||||
transformed.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct DeallocateFutureDataOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::DeallocateFutureDataOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
RT::DeallocateFutureDataOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::DeallocateFutureDataOp dfdOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::DeallocateFutureDataOp::Adaptor transformed(operands);
|
||||
auto dfdFuncType = LLVM::LLVMFunctionType::get(
|
||||
getVoidType(), {getVoidPtrI64Type(rewriter)});
|
||||
auto dfdFuncOp = getOrInsertFuncOpDecl(dfdOp, "_dfr_deallocate_future_data",
|
||||
dfdFuncType, rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(dfdOp, dfdFuncOp,
|
||||
transformed.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct BuildReturnPtrPlaceholderOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::BuildReturnPtrPlaceholderOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
RT::BuildReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::BuildReturnPtrPlaceholderOp befOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
// BuildReturnPtrPlaceholder is a placeholder for generating a memory
|
||||
// location where a pointer to allocated memory can be written so
|
||||
// that we can return outputs from task work function.
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
befOp.getLoc(),
|
||||
(*getTypeConverter()).convertType(rewriter.getIndexType()),
|
||||
rewriter.getIntegerAttr(
|
||||
(*getTypeConverter()).convertType(rewriter.getIndexType()), 1));
|
||||
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
|
||||
befOp, mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)),
|
||||
one,
|
||||
/*alignment=*/
|
||||
rewriter.getIntegerAttr(
|
||||
(*getTypeConverter()).convertType(rewriter.getIndexType()), 0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct DerefReturnPtrPlaceholderOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::DerefReturnPtrPlaceholderOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
RT::DerefReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::DerefReturnPtrPlaceholderOp drppOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::DerefReturnPtrPlaceholderOp::Adaptor transformed(operands);
|
||||
|
||||
// DerefReturnPtrPlaceholder is a placeholder for generating a
|
||||
// dereference operation for the pointer used to get results from
|
||||
// task.
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
|
||||
drppOp, transformed.getOperands().front());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<
|
||||
RT::DerefWorkFunctionArgumentPtrPlaceholderOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
RT::DerefWorkFunctionArgumentPtrPlaceholderOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::DerefWorkFunctionArgumentPtrPlaceholderOp dwfappOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::DerefWorkFunctionArgumentPtrPlaceholderOp::Adaptor transformed(
|
||||
operands);
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
// DerefWorkFunctionArgumentPtrPlaceholderOp is a placeholder for
|
||||
// generating a dereference operation for the pointer used to pass
|
||||
// arguments to the task.
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
|
||||
dwfappOp, transformed.getOperands().front());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct WorkFunctionReturnOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::WorkFunctionReturnOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
RT::WorkFunctionReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::WorkFunctionReturnOp wfrOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::WorkFunctionReturnOp::Adaptor transformed(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
|
||||
wfrOp, transformed.getOperands().front(),
|
||||
transformed.getOperands().back());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
/// Registers the RT-dialect-to-LLVM-dialect lowering patterns into
/// `patterns`, each constructed with the given LLVM type `converter`.
void mlir::zamalang::populateRTToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // All patterns are added in one call so they share the same converter.
  patterns.add<MakeReadyFutureOpInterfaceLowering,
               AwaitFutureOpInterfaceLowering,
               BuildReturnPtrPlaceholderOpInterfaceLowering,
               DerefReturnPtrPlaceholderOpInterfaceLowering,
               DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering,
               CreateAsyncTaskOpInterfaceLowering,
               DeallocateFutureOpInterfaceLowering,
               DeallocateFutureDataOpInterfaceLowering,
               WorkFunctionReturnOpInterfaceLowering>(converter);
}
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "zamalang/Dialect/RT/IR/RTDialect.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTOps.h"
|
||||
#include "zamalang/Dialect/RT/IR/RTTypes.h"
|
||||
@@ -24,7 +37,7 @@ void RTDialect::initialize() {
|
||||
::mlir::Type RTDialect::parseType(::mlir::DialectAsmParser &parser) const {
|
||||
mlir::Type type;
|
||||
if (parser.parseOptionalKeyword("future").succeeded()) {
|
||||
generatedTypeParser(this->getContext(), parser, "future", type);
|
||||
generatedTypeParser(parser, "future", type);
|
||||
return type;
|
||||
}
|
||||
return type;
|
||||
@@ -35,4 +48,4 @@ void RTDialect::printType(::mlir::Type type,
|
||||
if (generatedTypePrinter(type, printer).failed()) {
|
||||
printer.printType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
@@ -6,3 +10,17 @@
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "zamalang/Dialect/RT/IR/RTOps.cpp.inc"
|
||||
|
||||
using namespace mlir::zamalang::RT;
|
||||
|
||||
/// Populates `result` for a DataflowTaskOp: records the operands,
/// attributes and result types, and attaches a single region that holds one
/// empty block.
void DataflowTaskOp::build(
    ::mlir::OpBuilder &builder, ::mlir::OperationState &result,
    ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
    ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
  result.addOperands(operands);
  result.addAttributes(attributes);

  // The region takes ownership of the block pushed into it.
  Region *taskRegion = result.addRegion();
  taskRegion->push_back(new Block());

  result.addTypes(resultTypes);
}
|
||||
|
||||
@@ -28,6 +28,7 @@ add_mlir_library(ZamalangSupport
|
||||
LowLFHEUnparametrize
|
||||
MLIRLowerableDialectsToLLVM
|
||||
HLFHEDialectAnalysis
|
||||
RTDialectAnalysis
|
||||
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
@@ -84,6 +84,8 @@ void CompilerEngine::setVerifyDiagnostics(bool v) {
|
||||
this->verifyDiagnostics = v;
|
||||
}
|
||||
|
||||
/// Stores `v` in the engine's `autoParallelize` flag.
void CompilerEngine::setAutoParallelize(bool v) {
  this->autoParallelize = v;
}
|
||||
|
||||
void CompilerEngine::setGenerateClientParameters(bool v) {
|
||||
this->generateClientParameters = v;
|
||||
}
|
||||
@@ -215,6 +217,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return errorDiag("Tiling of HLFHELinalg operations failed");
|
||||
}
|
||||
|
||||
// Auto parallelization
|
||||
if (this->autoParallelize &&
|
||||
mlir::zamalang::pipeline::autopar(mlirContext, module, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Auto parallelization failed");
|
||||
}
|
||||
|
||||
if (target == Target::HLFHE)
|
||||
return std::move(res);
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <zamalang/Conversion/Passes.h>
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
|
||||
#include <zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h>
|
||||
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
|
||||
#include <zamalang/Support/Pipeline.h>
|
||||
#include <zamalang/Support/logging.h>
|
||||
#include <zamalang/Support/math.h>
|
||||
@@ -102,6 +103,17 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Runs the auto-parallelization pipeline on `module`.
///
/// Builds a pass manager on `context`, sets up printing under the name
/// "AutoPar", schedules the dataflow task graph construction pass (subject
/// to `enablePass`), and returns the result of running the pass manager on
/// the module's operation.
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
                            std::function<bool(mlir::Pass *)> enablePass) {
  mlir::PassManager passManager(&context);
  pipelinePrinting("AutoPar", passManager, context);

  addPotentiallyNestedPass(
      passManager, mlir::zamalang::createBuildDataflowTaskGraphPass(),
      enablePass);

  mlir::Operation *moduleOperation = module.getOperation();
  return passManager.run(moduleOperation);
}
|
||||
|
||||
mlir::LogicalResult
|
||||
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
@@ -190,8 +202,18 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createSCFBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createBufferizeDataflowTaskOpsPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
|
||||
enablePass);
|
||||
|
||||
// Lower Dataflow tasks to DRF
|
||||
addPotentiallyNestedPass(pm, mlir::zamalang::createFixupDataflowTaskOpsPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::zamalang::createLowerDataflowTasksPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createLowerToCFGPass(), enablePass);
|
||||
|
||||
// Convert to MLIR LLVM Dialect
|
||||
|
||||
@@ -123,6 +123,11 @@ llvm::cl::opt<bool> splitInputFile(
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// Command-line flag `parallelize` (boolean, off by default).
llvm::cl::opt<bool> autoParallelize(
    "parallelize",
    llvm::cl::desc("Generate (and execute if JIT) parallel code"),
    llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<std::string> jitFuncName(
|
||||
"jit-funcname",
|
||||
llvm::cl::desc("Name of the function to execute, default 'main'"),
|
||||
@@ -229,7 +234,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
llvm::Optional<llvm::ArrayRef<int64_t>> hlfhelinalgTileSizes,
|
||||
llvm::raw_ostream &os,
|
||||
bool autoParallelize, llvm::raw_ostream &os,
|
||||
std::shared_ptr<mlir::zamalang::CompilerEngine::Library> outputLib) {
|
||||
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
|
||||
mlir::zamalang::CompilationContext::createShared();
|
||||
@@ -237,6 +242,7 @@ mlir::LogicalResult processInputBuffer(
|
||||
mlir::zamalang::JitCompilerEngine ce{ccx};
|
||||
|
||||
ce.setVerifyDiagnostics(verifyDiagnostics);
|
||||
ce.setAutoParallelize(autoParallelize);
|
||||
if (cmdline::passes.size() != 0) {
|
||||
ce.setEnablePass([](mlir::Pass *pass) {
|
||||
return std::any_of(
|
||||
@@ -404,7 +410,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
std::move(inputBuffer), fileName, cmdline::action,
|
||||
cmdline::jitFuncName, cmdline::jitArgs,
|
||||
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
|
||||
cmdline::verifyDiagnostics, hlfhelinalgTileSizes, os, outputLib);
|
||||
cmdline::verifyDiagnostics, hlfhelinalgTileSizes,
|
||||
cmdline::autoParallelize, os, outputLib);
|
||||
};
|
||||
auto &os = output->os();
|
||||
auto res = mlir::failure();
|
||||
|
||||
Reference in New Issue
Block a user