feat(compiler): add Dataflow/RT dialect and code generation for dataflow auto parallelization.

Antoniu Pop
2021-12-20 12:20:33 +00:00
committed by Antoniu Pop
parent db683f4a0e
commit cdca7ca6f7
26 changed files with 1418 additions and 10 deletions


@@ -19,7 +19,7 @@ class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<HLFHE_Dialect, mnemonic, traits>;
// Generates an encrypted zero constant
def ZeroEintOp : HLFHE_Op<"zero"> {
def ZeroEintOp : HLFHE_Op<"zero", [NoSideEffect]> {
let arguments = (ins);
let results = (outs EncryptedIntegerType:$out);
}


@@ -0,0 +1,28 @@
#ifndef ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR_H
#define ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR_H
#include <functional>
#include <mlir/Pass/Pass.h>
#include <zamalang/Dialect/RT/IR/RTOps.h>
namespace mlir {
class LLVMTypeConverter;
class BufferizeTypeConverter;
class RewritePatternSet;
namespace zamalang {
std::unique_ptr<mlir::Pass>
createBuildDataflowTaskGraphPass(bool debug = false);
std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createBufferizeDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
mlir::RewritePatternSet &patterns);
void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns);
} // namespace zamalang
} // namespace mlir
#endif


@@ -0,0 +1,87 @@
#ifndef ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR
#define ZAMALANG_DIALECT_RT_ANALYSIS_AUTOPAR
include "mlir/Pass/PassBase.td"
def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> {
let summary =
"Identify profitable dataflow tasks and build DataflowTaskGraph.";
let description = [{
This pass builds a dataflow graph out of an HLFHE program.
In its current incarnation, it considers heavier-weight
operations (e.g., HLFHELinalg Dot and Matmul or bootstraps) as
candidates for execution in a discrete task, and then sinks
into each task the lighter-weight operations that do not
increase the graph cut (the number of dependences in or out).
The output is a program partitioned into RT::DataflowTaskOp
operations that expose task dependences as arguments and
results of the DataflowTaskOp.
Example:
```mlir
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
%0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
return %0 : tensor<3x2x!HLFHE.eint<2>>
}
```
This results in a dataflow task being generated for the Matmul operation:
```mlir
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
%0 = "RT.dataflow_task"(%arg0, %arg1) ( {
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
"RT.dataflow_yield"(%1) : (tensor<3x2x!HLFHE.eint<2>>) -> ()
}) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
return %0 : tensor<3x2x!HLFHE.eint<2>>
}
```
}];
}
def BufferizeDataflowTaskOps : Pass<"BufferizeDataflowTaskOps", "mlir::ModuleOp"> {
let summary =
"Bufferize DataflowTaskOp(s).";
let description = [{
This pass lowers DataflowTaskOp arguments and results from tensors
to memrefs. It also lowers the arguments of DataflowYieldOp.
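For illustration, a minimal sketch follows; the element types and the
`"Foo.bar"` operation are placeholders, not part of this dialect. A task
operating on tensors such as
```mlir
%0 = "RT.dataflow_task"(%t) ( {
  %1 = "Foo.bar"(%t) : (tensor<2x2xi64>) -> tensor<2x2xi64>
  "RT.dataflow_yield"(%1) : (tensor<2x2xi64>) -> ()
}) : (tensor<2x2xi64>) -> tensor<2x2xi64>
```
would be rewritten so that its operands, results and yielded values use
memrefs:
```mlir
%0 = "RT.dataflow_task"(%m) ( {
  %1 = "Foo.bar"(%m) : (memref<2x2xi64>) -> memref<2x2xi64>
  "RT.dataflow_yield"(%1) : (memref<2x2xi64>) -> ()
}) : (memref<2x2xi64>) -> memref<2x2xi64>
```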
}];
}
def FixupDataflowTaskOps : Pass<"FixupDataflowTaskOps", "mlir::ModuleOp"> {
let summary =
"Fix DataflowTaskOp(s) before lowering.";
let description = [{
This pass fixes up code changes introduced between the
BuildDataflowTaskGraph pass and the lowering of the task graph to
LLVM IR and calls to the DFR runtime system.
In particular, some operations (e.g., constants, dimension
operations) may be used within a task while being defined only
outside of it. In most cases, cloning and sinking these operations
into the task is the simplest way to avoid adding dependences.
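As a sketch (using standard `arith` operations for illustration), a
constant defined outside a task but used inside it:
```mlir
%c1 = arith.constant 1 : i64
%0 = "RT.dataflow_task"(%arg0) ( {
  %1 = arith.addi %arg0, %c1 : i64
  "RT.dataflow_yield"(%1) : (i64) -> ()
}) : (i64) -> i64
```
is cloned and sunk into the task body rather than being added as an
extra task operand:
```mlir
%0 = "RT.dataflow_task"(%arg0) ( {
  %c1 = arith.constant 1 : i64
  %1 = arith.addi %arg0, %c1 : i64
  "RT.dataflow_yield"(%1) : (i64) -> ()
}) : (i64) -> i64
```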
}];
}
def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
let summary =
"Outline the body of a DataflowTaskOp into a separate function which will serve as a task work function and lower the task graph to RT.";
let description = [{
This pass lowers a DataflowTaskGraph to the RT dialect, outlining
each DataflowTaskOp into a separate work function and introducing
the operations necessary to communicate and synchronize execution
via futures.
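The sketch below is schematic only: the real lowering also passes the
work-function pointer, operand/result counts and byte sizes to
`RT.create_async_task` (elided here), and the printed type syntax and
work-function name are indicative. A task
```mlir
%0 = "RT.dataflow_task"(%arg0) ( {
  ...
}) : (i64) -> i64
```
is outlined into a work function and replaced by runtime operations that
communicate through futures:
```mlir
// Wrap the immediate operand in a ready future.
%f0 = "RT.make_ready_future"(%arg0) : (i64) -> !RT.future<i64>
// Placeholder for the pointer through which the task returns its result.
%p0 = "RT.build_return_ptr_placeholder"() : () -> !RT.rtptr<!RT.future<i64>>
// Spawn the task (operand/result counts and sizes elided).
"RT.create_async_task"(%f0, %p0) {workfn = @_dfr_DFT_work_function__main0} : (!RT.future<i64>, !RT.rtptr<!RT.future<i64>>) -> ()
// Retrieve the result future and wait for it where the value is needed.
%f1 = "RT.deref_return_ptr_placeholder"(%p0) : (!RT.rtptr<!RT.future<i64>>) -> !RT.future<i64>
%r = "RT.await_future"(%f1) : (!RT.future<i64>) -> i64
```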
}];
}
#endif


@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Autopar.td)
mlir_tablegen(Autopar.h.inc -gen-pass-decls -name Analysis)
mlir_tablegen(Autopar.capi.h.inc -gen-pass-capi-header --prefix Analysis)
mlir_tablegen(Autopar.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis)
add_public_tablegen_target(AutoparPassIncGen)


@@ -1 +1,2 @@
add_subdirectory(Analysis)
add_subdirectory(IR)


@@ -3,6 +3,8 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "zamalang/Dialect/RT/IR/RTDialect.td"
include "zamalang/Dialect/RT/IR/RTTypes.td"
@@ -16,9 +18,24 @@ def DataflowTaskOp : RT_Op<"dataflow_task", [SingleBlockImplicitTerminator<"Data
let regions = (region AnyRegion:$body);
let builders = [
OpBuilder<(ins
CArg<"TypeRange", "{}">: $resultTypes,
CArg<"ValueRange", "{}">: $operands,
CArg<"ArrayRef<NamedAttribute>", "{}">: $attrs)>
];
let skipDefaultBuilders = 1;
let summary = "Dataflow task operation";
let description = [{
`RT.dataflow_task` allows to specify a task that will be concurrently executed when their operands are ready.
`RT.dataflow_task` specifies a task that is executed concurrently
once its operands are ready. Operands are either the results of
computations in other `RT.dataflow_task` operations (dataflow
dependences) or obtained from the execution context (immediate
operands). Operands are synchronized using futures and, in the case
of immediate operands, copied when the task is created. Caution is
required when an operand is a pointer, as no deep copy will occur.
Example:
@@ -43,6 +60,7 @@ func @test(%0 : i64): (i64, i64) {
}) : (i64, i64) -> (i64, i64)
return %3, %4 : (i64, i64)
}
```
}];
}
@@ -52,7 +70,7 @@ def DataflowYieldOp : RT_Op<"dataflow_yield", [ReturnLike, Terminator]> {
let summary = "Dataflow yield operation";
let description = [{
`RT.dataflow_yield` is a special terminator operation for blocks inside the region
in `RT.dataflow_task`. It allows to specify the returns values of a `RT.dataflow_task`.
in `RT.dataflow_task`. It allows to specify the return values of a `RT.dataflow_task`.
Example:
@@ -64,4 +82,64 @@ Example:
}];
}
def MakeReadyFutureOp : RT_Op<"make_ready_future"> {
let arguments = (ins AnyType: $input);
let results = (outs RT_Future: $output);
let summary = "Build a ready future.";
let description = [{
Data passed to dataflow tasks must be encapsulated in futures,
including immediate operands, which are converted into futures
using `RT.make_ready_future`.
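For example (type syntax indicative):
```mlir
%f = "RT.make_ready_future"(%v) : (i64) -> !RT.future<i64>
```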
}];
}
def AwaitFutureOp : RT_Op<"await_future"> {
let arguments = (ins RT_Future: $input);
let results = (outs AnyType: $output);
let summary = "Wait for a future and access its data.";
let description = [{
The results of a dataflow task are always futures, which can be
used as inputs to subsequent tasks. When the result of a task is
needed in the outer execution context, the result future must be
synchronized and its data accessed using `RT.await_future`.
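For example (type syntax indicative):
```mlir
%v = "RT.await_future"(%f) : (!RT.future<i64>) -> i64
```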
}];
}
def CreateAsyncTaskOp : RT_Op<"create_async_task"> {
let arguments = (ins SymbolRefAttr:$workfn,
Variadic<AnyType>:$list);
let results = (outs );
let summary = "Create a dataflow task.";
}
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
let arguments = (ins RT_Future: $input);
let results = (outs );
}
def DeallocateFutureDataOp : RT_Op<"deallocate_future_data"> {
let arguments = (ins RT_Future: $input);
let results = (outs );
}
def BuildReturnPtrPlaceholderOp : RT_Op<"build_return_ptr_placeholder"> {
let arguments = (ins );
let results = (outs RT_Pointer: $output);
}
def DerefReturnPtrPlaceholderOp : RT_Op<"deref_return_ptr_placeholder"> {
let arguments = (ins RT_Pointer: $input);
let results = (outs RT_Future: $output);
}
def DerefWorkFunctionArgumentPtrPlaceholderOp : RT_Op<"deref_work_function_argument_ptr_placeholder"> {
let arguments = (ins RT_Pointer: $input);
let results = (outs AnyType: $output);
}
def WorkFunctionReturnOp : RT_Op<"work_function_return"> {
let arguments = (ins AnyType:$in, AnyType:$out);
let results = (outs );
}
#endif


@@ -46,7 +46,40 @@ def RT_Future : RT_Type<"Future"> {
return Type();
return get($_ctxt, elementType);
}];
//let genVerifyDecl = 1;
}
def RT_Pointer : RT_Type<"Pointer"> {
let mnemonic = "rtptr";
let summary = "Pointer to a parameterized element type";
let description = [{
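Example of the printed form (assuming the `RT` dialect namespace), for a
pointer to `i64`:
```mlir
!RT.rtptr<i64>
```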
}];
let parameters = (ins "Type":$elementType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
return $_get(elementType.getContext(), elementType);
}]>
];
let printer = [{
$_printer << "rtptr<";
$_printer.printType(getElementType());
$_printer << ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
Type elementType;
if ($_parser.parseType(elementType))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, elementType);
}];
}
#endif


@@ -122,7 +122,7 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(),
clientParametersFuncName(), verifyDiagnostics(false),
generateClientParameters(false),
autoParallelize(false), generateClientParameters(false),
enablePass([](mlir::Pass *pass) { return true; }),
compilationContext(compilationContext) {}
@@ -146,6 +146,7 @@ public:
void setMaxEintPrecision(size_t v);
void setMaxMANP(size_t v);
void setVerifyDiagnostics(bool v);
void setAutoParallelize(bool v);
void setGenerateClientParameters(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
void setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
@@ -158,6 +159,7 @@ protected:
llvm::Optional<std::vector<int64_t>> hlfhelinalgTileSizes;
bool verifyDiagnostics;
bool autoParallelize;
bool generateClientParameters;
std::function<bool(mlir::Pass *)> enablePass;


@@ -12,6 +12,9 @@ namespace mlir {
namespace zamalang {
namespace pipeline {
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);


@@ -11,6 +11,7 @@
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Dialect/RT/IR/RTOps.h"
namespace {
struct HLFHEToMidLFHEPass : public HLFHEToMidLFHEBase<HLFHEToMidLFHEPass> {
@@ -92,6 +93,12 @@ void HLFHEToMidLFHEPass::runOnOperation() {
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::zamalang::RT::DataflowTaskOp>(
target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();


@@ -7,6 +7,7 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/RT/IR/RTOps.h"
#include "zamalang/Support/Constants.h"
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
@@ -123,6 +124,12 @@ void LowLFHEUnparametrizePass::runOnOperation() {
patterns.getContext(), converter);
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::CallOp>(target, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::zamalang::RT::DataflowTaskOp>(
target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();


@@ -19,6 +19,8 @@
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/RT/Analysis/Autopar.h"
#include "zamalang/Dialect/RT/IR/RTTypes.h"
namespace {
struct MLIRLowerableDialectsToLLVMPass
@@ -52,6 +54,7 @@ void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the set of the patterns rewriter. At this point we want to
// convert the `scf` operations to `std` and `std` operations to `llvm`.
mlir::RewritePatternSet patterns(&getContext());
mlir::zamalang::populateRTToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
@@ -72,10 +75,28 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) {
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>()) {
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>() ||
type.isa<mlir::zamalang::RT::FutureType>()) {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(type.getContext(), 64));
}
if (type.isa<mlir::zamalang::RT::PointerType>()) {
mlir::LowerToLLVMOptions options(type.getContext());
mlir::LLVMTypeConverter typeConverter(type.getContext(), options);
typeConverter.addConversion(convertTypes);
typeConverter.addConversion(
[&](mlir::zamalang::LowLFHE::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
typeConverter.addConversion(
[&](mlir::zamalang::LowLFHE::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
mlir::Type subtype =
type.dyn_cast<mlir::zamalang::RT::PointerType>().getElementType();
mlir::Type convertedSubtype = typeConverter.convertType(subtype);
return mlir::LLVM::LLVMPointerType::get(convertedSubtype);
}
return llvm::None;
}


@@ -7,6 +7,7 @@
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Dialect/RT/IR/RTOps.h"
#include "zamalang/Support/Constants.h"
namespace {
@@ -300,6 +301,12 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
mlir::zamalang::addDynamicallyLegalTypeOp<
mlir::zamalang::RT::DataflowTaskOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {


@@ -11,6 +11,7 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Dialect/RT/IR/RTOps.h"
namespace {
struct MidLFHEToLowLFHEPass
@@ -89,6 +90,12 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::zamalang::GenericTypeConverterPattern<
mlir::zamalang::RT::DataflowTaskOp>>(patterns.getContext(), converter);
mlir::zamalang::addDynamicallyLegalTypeOp<mlir::zamalang::RT::DataflowTaskOp>(
target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();


@@ -0,0 +1,119 @@
#include <iostream>
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
#include <zamalang/Dialect/RT/IR/RTDialect.h>
#include <zamalang/Dialect/RT/IR/RTOps.h>
#include <zamalang/Dialect/RT/IR/RTTypes.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/Transforms/Bufferize.h>
#include <mlir/Transforms/RegionUtils.h>
#include <zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#define GEN_PASS_CLASSES
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
namespace mlir {
namespace zamalang {
namespace {
class BufferizeDataflowYieldOp
: public OpConversionPattern<RT::DataflowYieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RT::DataflowYieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::DataflowYieldOp::Adaptor transformed(operands);
rewriter.replaceOpWithNewOp<RT::DataflowYieldOp>(op, mlir::TypeRange(),
transformed.getOperands());
return success();
}
};
} // namespace
namespace {
class BufferizeDataflowTaskOp : public OpConversionPattern<RT::DataflowTaskOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RT::DataflowTaskOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::DataflowTaskOp::Adaptor transformed(operands);
mlir::OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Type> newResults;
(void)getTypeConverter()->convertTypes(op.getResultTypes(), newResults);
auto newop = rewriter.create<RT::DataflowTaskOp>(op.getLoc(), newResults,
transformed.getOperands());
// We cannot clone here as cloned ops must be legalized (so this
// would break on the YieldOp). Instead use mergeBlocks which
// moves the ops instead of cloning.
rewriter.mergeBlocks(op.getBody(), newop.getBody(),
newop.getBody()->getArguments());
// Because of previous bufferization there are buffer cast ops
// that have been generated for the previously tensor results of
// some tasks. These cannot just be replaced directly as the
// task's results would still be live.
for (auto res : llvm::enumerate(op.getResults())) {
// If this result is getting bufferized ...
if (res.value().getType() !=
getTypeConverter()->convertType(res.value().getType())) {
for (auto &use : llvm::make_early_inc_range(res.value().getUses())) {
// ... and its uses are in `BufferCastOp`s, then we
// replace further uses of the buffer cast.
if (isa<mlir::memref::BufferCastOp>(use.getOwner())) {
rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())});
}
}
}
}
rewriter.replaceOp(op, {newop.getResults()});
return success();
}
};
} // namespace
void populateRTBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeDataflowYieldOp, BufferizeDataflowTaskOp>(
typeConverter, patterns.getContext());
}
namespace {
// For documentation see Autopar.td
struct BufferizeDataflowTaskOpsPass
: public BufferizeDataflowTaskOpsBase<BufferizeDataflowTaskOpsPass> {
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateRTBufferizePatterns(typeConverter, patterns);
// Forbid all RT ops that still use/return tensors
target.addDynamicallyLegalDialect<RT::RTDialect>(
[&](Operation *op) { return typeConverter.isLegal(op); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
BufferizeDataflowTaskOpsPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createBufferizeDataflowTaskOpsPass(bool debug) {
return std::make_unique<BufferizeDataflowTaskOpsPass>(debug);
}
} // namespace zamalang
} // namespace mlir


@@ -0,0 +1,266 @@
#include <iostream>
#include <mlir/IR/BuiltinOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h>
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
#include <zamalang/Dialect/RT/IR/RTDialect.h>
#include <zamalang/Dialect/RT/IR/RTOps.h>
#include <zamalang/Dialect/RT/IR/RTTypes.h>
#include <zamalang/Support/Constants.h>
#include <zamalang/Support/math.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
#include <mlir/Transforms/Passes.h>
#include <mlir/Transforms/RegionUtils.h>
#include <mlir/Transforms/Utils.h>
#define GEN_PASS_CLASSES
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
namespace mlir {
namespace zamalang {
namespace {
// TODO: adjust these two functions based on cost model
static bool isCandidateForTask(Operation *op) {
return isa<HLFHE::AddEintIntOp, HLFHE::AddEintOp, HLFHE::SubIntEintOp,
HLFHE::MulEintIntOp, HLFHE::ApplyLookupTableEintOp,
HLFHELinalg::MatMulIntEintOp, HLFHELinalg::MatMulEintIntOp,
HLFHELinalg::AddEintIntOp, HLFHELinalg::AddEintOp,
HLFHELinalg::SubIntEintOp, HLFHELinalg::NegEintOp,
HLFHELinalg::MulEintIntOp, HLFHELinalg::ApplyLookupTableEintOp,
HLFHELinalg::ApplyMultiLookupTableEintOp, HLFHELinalg::Dot>(op);
}
// Identify operations that are beneficial to sink into tasks. These
// operations must not have side effects and must not themselves be
// candidates for tasks (see `isCandidateForTask`).
static bool isSinkingBeneficiary(Operation *op) {
return isa<HLFHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, SelectOp,
mlir::arith::CmpIOp>(op);
}
static bool
extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
SetVector<Operation *> &beneficiaryOps,
llvm::SmallPtrSetImpl<Value> &availableValues) {
if (beneficiaryOps.count(op))
return true;
if (!isSinkingBeneficiary(op))
return false;
for (Value operand : op->getOperands()) {
// It is already visible in the kernel, keep going.
if (availableValues.count(operand))
continue;
// Else check whether it can be made available via sinking or already is a
// dependency.
Operation *definingOp = operand.getDefiningOp();
if ((!definingOp ||
!extractBeneficiaryOps(definingOp, existingDependencies,
beneficiaryOps, availableValues)) &&
!existingDependencies.count(operand))
return false;
}
// We will sink the operation, mark its results as now available.
beneficiaryOps.insert(op);
for (Value result : op->getResults())
availableValues.insert(result);
return true;
}
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
Region &taskOpBody = taskOp.body();
// Identify uses from values defined outside of the scope.
SetVector<Value> sinkCandidates;
getUsedValuesDefinedAbove(taskOpBody, sinkCandidates);
SetVector<Operation *> toBeSunk;
llvm::SmallPtrSet<Value, 4> availableValues;
for (Value operand : sinkCandidates) {
Operation *operandOp = operand.getDefiningOp();
if (!operandOp)
continue;
extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues);
}
// Insert operations so that the defs get cloned before uses.
BlockAndValueMapping map;
OpBuilder builder(taskOpBody);
for (Operation *op : toBeSunk) {
OpBuilder::InsertionGuard guard(builder);
Operation *clonedOp = builder.clone(*op, map);
for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults()))
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
taskOp.body());
// Once this is sunk, remove all operands of the DFT covered by this
for (auto result : op->getResults())
for (auto operand : llvm::enumerate(taskOp.getOperands()))
if (operand.value() == result) {
taskOp->eraseOperand(operand.index());
// Once removed, we assume there are no duplicates
break;
}
}
return success();
}
// For documentation see Autopar.td
struct BuildDataflowTaskGraphPass
: public BuildDataflowTaskGraphBase<BuildDataflowTaskGraphPass> {
void runOnOperation() override {
auto module = getOperation();
module.walk([&](mlir::FuncOp func) {
if (!func->getAttr("_dfr_work_function_attribute"))
func.walk(
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
// Perform simplifications, in particular DCE, here in case some
// of the operations sunk into tasks are no longer needed in the
// main function. If simplification fails, it only means that
// nothing was simplified. Doing this here - rather than later
// in the compilation pipeline - lets us take advantage of
// higher-level semantics that we can attach to operations
// (e.g., NoSideEffect on HLFHE::ZeroEintOp).
IRRewriter rewriter(func->getContext());
(void)mlir::simplifyRegions(rewriter, func->getRegions());
});
}
BuildDataflowTaskGraphPass(bool debug) : debug(debug){};
protected:
void processOperation(mlir::Operation *op) {
if (isCandidateForTask(op)) {
BlockAndValueMapping map;
Region &opBody = getOperation().body();
OpBuilder builder(opBody);
// Create a DFTask for this operation
builder.setInsertionPointAfter(op);
auto dftop = builder.create<RT::DataflowTaskOp>(
op->getLoc(), op->getResultTypes(), op->getOperands());
// Add the operation to the task
OpBuilder tbbuilder(dftop.body());
Operation *clonedOp = tbbuilder.clone(*op, map);
// Add sinkable operations to the task
assert(!failed(sinkOperationsIntoDFTask(dftop)) &&
"Failing to sink operations into DFT");
// Add terminator
tbbuilder.create<RT::DataflowYieldOp>(dftop.getLoc(), mlir::TypeRange(),
op->getResults());
// Replace the uses of defined values
for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults()))
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
dftop.body());
// Replace uses of the values defined by the task
for (auto pair : llvm::zip(op->getResults(), dftop->getResults()))
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
opBody);
// Once uses are re-targeted to the task, delete the operation
op->erase();
}
}
bool debug;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
}
namespace {
// Marker to avoid infinite recursion of the rewriting pattern
static const mlir::StringLiteral kTransformMarker =
"_internal_RT_FixDataflowTaskOpInputsPattern_marker__";
class FixDataflowTaskOpInputsPattern
: public mlir::OpRewritePattern<RT::DataflowTaskOp> {
public:
FixDataflowTaskOpInputsPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<RT::DataflowTaskOp>(
context, ::mlir::zamalang::DEFAULT_PATTERN_BENEFIT) {}
LogicalResult
matchAndRewrite(RT::DataflowTaskOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
if (op->hasAttr(kTransformMarker))
return failure();
// Identify which values need to be passed as dependences to the
// task - this is very conservative and will add constants, index
// operations, etc. A simplification will occur later.
SetVector<Value> deps;
getUsedValuesDefinedAbove(op.body(), deps);
auto newop = rewriter.create<RT::DataflowTaskOp>(
op.getLoc(), op.getResultTypes(), deps.getArrayRef());
rewriter.mergeBlocks(op.getBody(), newop.getBody(),
newop.getBody()->getArguments());
rewriter.replaceOp(op, {newop.getResults()});
// Mark this as processed to prevent infinite loop
newop.getOperation()->setAttr(kTransformMarker, rewriter.getUnitAttr());
return success();
}
};
} // namespace
namespace {
// For documentation see Autopar.td
struct FixupDataflowTaskOpsPass
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FixDataflowTaskOpInputsPattern>(context);
if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))
.failed())
signalPassFailure();
// Clear mark and sink any newly created constants or indexing
// operations, etc. to reduce the number of input dependences to
// the task
module->walk([](RT::DataflowTaskOp op) {
op.getOperation()->removeAttr(kTransformMarker);
assert(!failed(sinkOperationsIntoDFTask(op)) &&
"Failing to sink operations into DFT");
});
}
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug) {
return std::make_unique<FixupDataflowTaskOpsPass>(debug);
}
} // end namespace zamalang
} // end namespace mlir


@@ -0,0 +1,18 @@
add_mlir_library(RTDialectAnalysis
BufferizeDataflowTaskOps.cpp
BuildDataflowTaskGraph.cpp
LowerDataflowTasksToRT.cpp
LowerRTToLLVMDFRCallsConversionPatterns.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/RT
DEPENDS
RTDialect
AutoparPassIncGen
LINK_LIBS PUBLIC
MLIRIR
RTDialect)
target_link_libraries(RTDialectAnalysis PUBLIC MLIRIR)


@@ -0,0 +1,337 @@
#include <iostream>
#include <mlir/IR/BuiltinOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h>
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
#include <zamalang/Dialect/RT/IR/RTDialect.h>
#include <zamalang/Dialect/RT/IR/RTOps.h>
#include <zamalang/Dialect/RT/IR/RTTypes.h>
#include <zamalang/Support/math.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Dialect/StandardOps/Transforms/FuncConversions.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Bufferize.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>
#include <mlir/Transforms/RegionUtils.h>
#include <mlir/Transforms/Utils.h>
#include <zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#define GEN_PASS_CLASSES
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
namespace mlir {
namespace zamalang {
namespace {
static FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
StringRef workFunctionName) {
Location loc = DFTOp.getLoc();
OpBuilder builder(DFTOp.getContext());
Region &DFTOpBody = DFTOp.body();
OpBuilder::InsertionGuard guard(builder);
// Instead of outlining with the same operands/results, we pass all
// results as operands as well. For now we preserve the results'
// types, which will be changed to use an indirection when lowering.
SmallVector<Type, 4> operandTypes;
operandTypes.reserve(DFTOp.getNumOperands() + DFTOp.getNumResults());
for (Value operand : DFTOp.getOperands())
operandTypes.push_back(RT::PointerType::get(operand.getType()));
for (Value res : DFTOp.getResults())
operandTypes.push_back(RT::PointerType::get(res.getType()));
FunctionType type = FunctionType::get(DFTOp.getContext(), operandTypes, {});
auto outlinedFunc = builder.create<FuncOp>(loc, workFunctionName, type);
outlinedFunc->setAttr("_dfr_work_function_attribute", builder.getUnitAttr());
Region &outlinedFuncBody = outlinedFunc.body();
Block *outlinedEntryBlock = new Block;
outlinedEntryBlock->addArguments(type.getInputs());
outlinedFuncBody.push_back(outlinedEntryBlock);
BlockAndValueMapping map;
Block &entryBlock = outlinedFuncBody.front();
builder.setInsertionPointToStart(&entryBlock);
for (auto operand : llvm::enumerate(DFTOp.getOperands())) {
// Add deref of arguments and remap to operands in the body
auto derefdop =
builder.create<RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
DFTOp.getLoc(), operand.value().getType(),
entryBlock.getArgument(operand.index()));
map.map(operand.value(), derefdop->getResult(0));
}
DFTOpBody.cloneInto(&outlinedFuncBody, map);
Block &DFTOpEntry = DFTOpBody.front();
Block *clonedDFTOpEntry = map.lookup(&DFTOpEntry);
builder.setInsertionPointToEnd(&entryBlock);
builder.create<BranchOp>(loc, clonedDFTOpEntry);
// TODO: we use a WorkFunctionReturnOp to tie return to the
// corresponding argument. This can be lowered to a copy/deref for
// shared memory and pointers, but needs to be handled for
// distributed memory.
outlinedFunc.walk([&](RT::DataflowYieldOp op) {
OpBuilder replacer(op);
int output_offset = DFTOp.getNumOperands();
for (auto ret : llvm::enumerate(op.getOperands()))
replacer.create<RT::WorkFunctionReturnOp>(
op.getLoc(), ret.value(),
outlinedFunc.getArgument(ret.index() + output_offset));
replacer.create<ReturnOp>(op.getLoc());
op.erase();
});
return outlinedFunc;
}
static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
if (isa<RT::DataflowTaskOp>(use.getOwner()) &&
region.isAncestor(use.getOwner()->getParentRegion()))
use.set(replacement);
}
}
static void replaceAllUsesNotInDFTsInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
if (!isa<RT::DataflowTaskOp>(use.getOwner()) &&
use.getOwner()->getParentOfType<RT::DataflowTaskOp>() == nullptr &&
region.isAncestor(use.getOwner()->getParentRegion()))
use.set(replacement);
}
}
// TODO: Fix type sizes. For now we're using some default values.
static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
DataLayout dataLayout = DataLayout::closest(val.getDefiningOp());
Type type = (val.getType().isa<RT::FutureType>())
? val.getType().dyn_cast<RT::FutureType>().getElementType()
: val.getType();
// In the case of memref, we need to determine how much space
// (conservatively) we need to store the memref itself. Overshooting
// by a few bytes should not be an issue, so the main thing is to
// properly account for the rank.
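// For example, a rank-2 memref is budgeted as 24 + 2 * 16 = 56 bytes.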
if (type.isa<mlir::MemRefType>()) {
// Space for the allocated and aligned pointers, and offset
Value ptrs_offset =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(24));
// For the sizes and strides arrays, we need 2 * 8 = 16 bytes per rank.
Value multiplier =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
unsigned _rank = type.dyn_cast<mlir::MemRefType>().getRank();
Value rank = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_rank));
Value sizes_shapes = builder.create<LLVM::MulOp>(loc, rank, multiplier);
Value result = builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
return result;
}
// Unranked memrefs should be lowered to just pointer + size, so we need 16
// bytes.
if (type.isa<mlir::UnrankedMemRefType>())
return builder.create<arith::ConstantOp>(loc,
builder.getI64IntegerAttr(16));
// FHE types are converted to pointers, so we take their size as 8
// bytes until we can query the actual size of the underlying types.
if (type.isa<mlir::zamalang::LowLFHE::ContextType>() ||
type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>() ||
type.isa<mlir::zamalang::LowLFHE::GlweCiphertextType>() ||
type.isa<mlir::zamalang::LowLFHE::LweKeySwitchKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::LweBootstrapKeyType>() ||
type.isa<mlir::zamalang::LowLFHE::ForeignPlaintextListType>() ||
type.isa<mlir::zamalang::LowLFHE::PlaintextListType>())
return builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
// For all other types, get type size.
return builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type)));
}
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, FuncOp workFunction) {
DataLayout dataLayout = DataLayout::closest(DFTOp);
Region &opBody = DFTOp->getParentOfType<FuncOp>().body();
BlockAndValueMapping map;
OpBuilder builder(DFTOp);
// First identify DFT operands that are not futures and are not
// defined by another DFT. These need to be made into futures and
// propagated to all other DFTs. We can allow PRE to eliminate the
// previous definitions if there are no non-future type uses.
builder.setInsertionPoint(DFTOp);
for (Value val : DFTOp.getOperands()) {
if (!val.getType().isa<RT::FutureType>()) {
Type futType = RT::FutureType::get(val.getType());
auto mrf =
builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType, val);
map.map(mrf->getResult(0), val);
replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody);
}
}
// Second generate a CreateAsyncTaskOp that will replace the
// DataflowTaskOp. This also includes the necessary handling of
// operands and results (conversion to/from futures and propagation).
SmallVector<Value, 4> catOperands;
int size = 3 + DFTOp.getNumResults() * 2 + DFTOp.getNumOperands() * 2;
catOperands.reserve(size);
auto fnptr = builder.create<mlir::ConstantOp>(
DFTOp.getLoc(), workFunction.getType(),
SymbolRefAttr::get(builder.getContext(), workFunction.getName()));
auto numIns = builder.create<arith::ConstantOp>(
DFTOp.getLoc(), builder.getI64IntegerAttr(DFTOp.getNumOperands()));
auto numOuts = builder.create<arith::ConstantOp>(
DFTOp.getLoc(), builder.getI64IntegerAttr(DFTOp.getNumResults()));
catOperands.push_back(fnptr.getResult());
catOperands.push_back(numIns.getResult());
catOperands.push_back(numOuts.getResult());
for (auto operand : DFTOp.getOperands()) {
catOperands.push_back(operand);
catOperands.push_back(getSizeInBytes(operand, DFTOp.getLoc(), builder));
}
// We need to adjust the results of the CreateAsyncTaskOp: the work
// function returns its results through pointers passed as
// parameters. As this is not supported within MLIR - and mostly
// unsupported even in the LLVMIR Dialect - we use two placeholders
// for each output, one before and one after the CreateAsyncTaskOp.
for (auto result : DFTOp.getResults()) {
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
futType);
map.map(result, brpp->getResult(0));
catOperands.push_back(brpp->getResult(0));
catOperands.push_back(getSizeInBytes(result, DFTOp.getLoc(), builder));
}
builder.create<RT::CreateAsyncTaskOp>(
DFTOp.getLoc(),
SymbolRefAttr::get(builder.getContext(), workFunction.getName()),
catOperands);
// Third identify results of this DFT that are not used *only* in
// other DFTs as those will need to be waited on explicitly.
// We also create the DerefReturnPtrPlaceholderOp after the
// CreateAsyncTaskOp. These also need propagating.
for (auto result : DFTOp.getResults()) {
Type futType = RT::FutureType::get(result.getType());
Value futptr = map.lookupOrNull(result);
assert(futptr);
auto drpp = builder.create<RT::DerefReturnPtrPlaceholderOp>(
DFTOp.getLoc(), futType, futptr);
replaceAllUsesInDFTsInRegionWith(result, drpp->getResult(0), opBody);
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
if (!isa<RT::DataflowTaskOp>(use.getOwner()) &&
use.getOwner()->getParentOfType<RT::DataflowTaskOp>() == nullptr) {
// Wait for this future
// TODO: the wait function should ideally
// be issued as late as possible, but need to identify which
// use comes first.
auto af = builder.create<RT::AwaitFutureOp>(
DFTOp.getLoc(), result.getType(), drpp.getResult());
replaceAllUsesNotInDFTsInRegionWith(result, af->getResult(0), opBody);
// We only need to do this once; propagation will hit all
// other uses.
break;
}
}
// All leftover uses (i.e., those within DFTs) should use the future.
replaceAllUsesInRegionWith(result, futptr, opBody);
}
// Finally erase the DFT.
DFTOp.erase();
}
// For documentation see Autopar.td
struct LowerDataflowTasksPass
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
void runOnOperation() override {
auto module = getOperation();
module.walk([&](mlir::FuncOp func) {
int wfn_id = 0;
// TODO: For now do not attempt to use nested parallelism.
if (func->getAttr("_dfr_work_function_attribute"))
return;
SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func);
std::vector<std::pair<RT::DataflowTaskOp, FuncOp>> outliningMap;
func.walk([&](RT::DataflowTaskOp op) {
auto workFunctionName = Twine("_dfr_DFT_work_function__") +
Twine(op->getParentOfType<FuncOp>().getName()) +
Twine(wfn_id++);
FuncOp outlinedFunc = outlineWorkFunction(op, workFunctionName.str());
outliningMap.push_back(
std::pair<RT::DataflowTaskOp, FuncOp>(op, outlinedFunc));
symbolTable.insert(outlinedFunc);
return WalkResult::advance();
});
// Lower the DF task ops to RT dialect ops.
for (auto mapping : outliningMap)
lowerDataflowTaskOp(mapping.first, mapping.second);
// Issue _dfr_start/stop calls for this function
if (!outliningMap.empty()) {
OpBuilder builder(func.body());
builder.setInsertionPointToStart(&func.body().front());
auto dfrStartFunOp = mlir::LLVM::lookupOrCreateFn(
func->getParentOfType<ModuleOp>(), "_dfr_start", {},
LLVM::LLVMVoidType::get(func->getContext()));
builder.create<LLVM::CallOp>(func.getLoc(), dfrStartFunOp,
mlir::ValueRange(),
ArrayRef<NamedAttribute>());
builder.setInsertionPoint(func.body().back().getTerminator());
auto dfrStopFunOp = mlir::LLVM::lookupOrCreateFn(
func->getParentOfType<ModuleOp>(), "_dfr_stop", {},
LLVM::LLVMVoidType::get(func->getContext()));
builder.create<LLVM::CallOp>(func.getLoc(), dfrStopFunOp,
mlir::ValueRange(),
ArrayRef<NamedAttribute>());
}
});
}
LowerDataflowTasksPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug) {
return std::make_unique<LowerDataflowTasksPass>(debug);
}
} // end namespace zamalang
} // end namespace mlir


@@ -0,0 +1,310 @@
#include <iostream>
#include <mlir/IR/BuiltinOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEOps.h>
#include <zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h>
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
#include <zamalang/Dialect/RT/IR/RTDialect.h>
#include <zamalang/Dialect/RT/IR/RTOps.h>
#include <zamalang/Dialect/RT/IR/RTTypes.h>
#include <zamalang/Support/math.h>
#include <llvm/IR/Instructions.h>
#include <llvm/Support/Compiler.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
#include <mlir/Dialect/LLVMIR/FunctionCallUtils.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Dialect/StandardOps/Transforms/FuncConversions.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Bufferize.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>
#include <mlir/Transforms/RegionUtils.h>
#include <mlir/Transforms/Utils.h>
#include <zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#define GEN_PASS_CLASSES
#include <zamalang/Dialect/RT/Analysis/Autopar.h.inc>
namespace mlir {
namespace zamalang {
namespace {
mlir::Type getVoidPtrI64Type(ConversionPatternRewriter &rewriter) {
return mlir::LLVM::LLVMPointerType::get(
mlir::IntegerType::get(rewriter.getContext(), 64));
}
LLVM::LLVMFuncOp getOrInsertFuncOpDecl(mlir::Operation *op,
llvm::StringRef funcName,
LLVM::LLVMFunctionType funcType,
ConversionPatternRewriter &rewriter) {
// Check if the function is already in the symbol table
auto module = op->getParentOfType<ModuleOp>();
auto funcOp = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName);
if (!funcOp) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
funcOp =
rewriter.create<LLVM::LLVMFuncOp>(op->getLoc(), funcName, funcType);
funcOp.setPrivate();
} else {
if (!funcOp.isPrivate()) {
op->emitError()
<< "the function \"" << funcName
<< "\" conflicts with the Dataflow Runtime API, please rename.";
return nullptr;
}
}
return funcOp;
}
// This function is only needed for debug purposes to inspect values
// in the generated code - it is therefore not generally in use.
LLVM_ATTRIBUTE_UNUSED void
insertPrintDebugCall(ConversionPatternRewriter &rewriter, mlir::Operation *op,
Value val) {
OpBuilder::InsertionGuard guard(rewriter);
auto printFnType = LLVM::LLVMFunctionType::get(
LLVM::LLVMVoidType::get(rewriter.getContext()), {}, /*isVariadic=*/true);
auto printFnOp =
getOrInsertFuncOpDecl(op, "_dfr_print_debug", printFnType, rewriter);
rewriter.create<LLVM::CallOp>(op->getLoc(), printFnOp, val);
}
struct MakeReadyFutureOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::MakeReadyFutureOp> {
using ConvertOpToLLVMPattern<RT::MakeReadyFutureOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::MakeReadyFutureOp mrfOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::MakeReadyFutureOp::Adaptor transformed(operands);
OpBuilder::InsertionGuard guard(rewriter);
// Normally this function takes a pointer as parameter
auto mrfFuncType = LLVM::LLVMFunctionType::get(getVoidPtrI64Type(rewriter),
{}, /*isVariadic=*/true);
auto mrfFuncOp = getOrInsertFuncOpDecl(mrfOp, "_dfr_make_ready_future",
mrfFuncType, rewriter);
// In order to support non-pointer types, we explicitly allocate
// space that we can reference as a base for the future.
auto allocFuncOp = mlir::LLVM::lookupOrCreateMallocFn(
mrfOp->getParentOfType<ModuleOp>(), getIndexType());
auto sizeBytes = getSizeInBytes(
mrfOp.getLoc(), transformed.getOperands().getTypes().front(), rewriter);
auto results = mlir::LLVM::createLLVMCall(
rewriter, mrfOp.getLoc(), allocFuncOp, {sizeBytes}, getVoidPtrType());
Value allocatedPtr = rewriter.create<mlir::LLVM::BitcastOp>(
mrfOp.getLoc(),
mlir::LLVM::LLVMPointerType::get(
transformed.getOperands().getTypes().front()),
results[0]);
rewriter.create<LLVM::StoreOp>(
mrfOp.getLoc(), transformed.getOperands().front(), allocatedPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, allocatedPtr);
return mlir::success();
}
};
struct AwaitFutureOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::AwaitFutureOp> {
using ConvertOpToLLVMPattern<RT::AwaitFutureOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::AwaitFutureOp afOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::AwaitFutureOp::Adaptor transformed(operands);
OpBuilder::InsertionGuard guard(rewriter);
auto afFuncType = LLVM::LLVMFunctionType::get(
mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)),
{getVoidPtrI64Type(rewriter)});
auto afFuncOp =
getOrInsertFuncOpDecl(afOp, "_dfr_await_future", afFuncType, rewriter);
auto afCallOp = rewriter.create<LLVM::CallOp>(afOp.getLoc(), afFuncOp,
transformed.getOperands());
Value futVal = rewriter.create<mlir::LLVM::BitcastOp>(
afOp.getLoc(),
mlir::LLVM::LLVMPointerType::get(
(*getTypeConverter()).convertType(afOp.getResult().getType())),
afCallOp.getResult(0));
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(afOp, futVal);
return success();
}
};
struct CreateAsyncTaskOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::CreateAsyncTaskOp> {
using ConvertOpToLLVMPattern<RT::CreateAsyncTaskOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::CreateAsyncTaskOp catOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::CreateAsyncTaskOp::Adaptor transformed(operands);
auto catFuncType =
LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
auto catFuncOp = getOrInsertFuncOpDecl(catOp, "_dfr_create_async_task",
catFuncType, rewriter);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(catOp, catFuncOp,
transformed.getOperands());
return success();
}
};
struct DeallocateFutureOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::DeallocateFutureOp> {
using ConvertOpToLLVMPattern<RT::DeallocateFutureOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::DeallocateFutureOp dfOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::DeallocateFutureOp::Adaptor transformed(operands);
auto dfFuncType = LLVM::LLVMFunctionType::get(
getVoidType(), {getVoidPtrI64Type(rewriter)});
auto dfFuncOp = getOrInsertFuncOpDecl(dfOp, "_dfr_deallocate_future",
dfFuncType, rewriter);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(dfOp, dfFuncOp,
transformed.getOperands());
return success();
}
};
struct DeallocateFutureDataOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::DeallocateFutureDataOp> {
using ConvertOpToLLVMPattern<
RT::DeallocateFutureDataOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::DeallocateFutureDataOp dfdOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::DeallocateFutureDataOp::Adaptor transformed(operands);
auto dfdFuncType = LLVM::LLVMFunctionType::get(
getVoidType(), {getVoidPtrI64Type(rewriter)});
auto dfdFuncOp = getOrInsertFuncOpDecl(dfdOp, "_dfr_deallocate_future_data",
dfdFuncType, rewriter);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(dfdOp, dfdFuncOp,
transformed.getOperands());
return success();
}
};
struct BuildReturnPtrPlaceholderOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::BuildReturnPtrPlaceholderOp> {
using ConvertOpToLLVMPattern<
RT::BuildReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::BuildReturnPtrPlaceholderOp befOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OpBuilder::InsertionGuard guard(rewriter);
// BuildReturnPtrPlaceholder is a placeholder for generating a memory
// location where a pointer to allocated memory can be written so
// that we can return outputs from the task work function.
Value one = rewriter.create<arith::ConstantOp>(
befOp.getLoc(),
(*getTypeConverter()).convertType(rewriter.getIndexType()),
rewriter.getIntegerAttr(
(*getTypeConverter()).convertType(rewriter.getIndexType()), 1));
rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
befOp, mlir::LLVM::LLVMPointerType::get(getVoidPtrI64Type(rewriter)),
one,
/*alignment=*/
rewriter.getIntegerAttr(
(*getTypeConverter()).convertType(rewriter.getIndexType()), 0));
return success();
}
};
struct DerefReturnPtrPlaceholderOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::DerefReturnPtrPlaceholderOp> {
using ConvertOpToLLVMPattern<
RT::DerefReturnPtrPlaceholderOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::DerefReturnPtrPlaceholderOp drppOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::DerefReturnPtrPlaceholderOp::Adaptor transformed(operands);
// DerefReturnPtrPlaceholder is a placeholder for generating a
// dereference operation for the pointer used to get results from
// the task.
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
drppOp, transformed.getOperands().front());
return success();
}
};
struct DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering
: public ConvertOpToLLVMPattern<
RT::DerefWorkFunctionArgumentPtrPlaceholderOp> {
using ConvertOpToLLVMPattern<
RT::DerefWorkFunctionArgumentPtrPlaceholderOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::DerefWorkFunctionArgumentPtrPlaceholderOp dwfappOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::DerefWorkFunctionArgumentPtrPlaceholderOp::Adaptor transformed(
operands);
OpBuilder::InsertionGuard guard(rewriter);
// DerefWorkFunctionArgumentPtrPlaceholderOp is a placeholder for
// generating a dereference operation for the pointer used to pass
// arguments to the task.
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
dwfappOp, transformed.getOperands().front());
return success();
}
};
struct WorkFunctionReturnOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::WorkFunctionReturnOp> {
using ConvertOpToLLVMPattern<
RT::WorkFunctionReturnOp>::ConvertOpToLLVMPattern;
mlir::LogicalResult
matchAndRewrite(RT::WorkFunctionReturnOp wfrOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
RT::WorkFunctionReturnOp::Adaptor transformed(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
wfrOp, transformed.getOperands().front(),
transformed.getOperands().back());
return success();
}
};
} // end anonymous namespace
} // namespace zamalang
} // namespace mlir
void mlir::zamalang::populateRTToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
MakeReadyFutureOpInterfaceLowering,
AwaitFutureOpInterfaceLowering,
BuildReturnPtrPlaceholderOpInterfaceLowering,
DerefReturnPtrPlaceholderOpInterfaceLowering,
DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering,
CreateAsyncTaskOpInterfaceLowering,
DeallocateFutureOpInterfaceLowering,
DeallocateFutureDataOpInterfaceLowering,
WorkFunctionReturnOpInterfaceLowering>(converter);
// clang-format on
}


@@ -1 +1,2 @@
add_subdirectory(Analysis)
add_subdirectory(IR)


@@ -1,3 +1,16 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "zamalang/Dialect/RT/IR/RTDialect.h"
#include "zamalang/Dialect/RT/IR/RTOps.h"
#include "zamalang/Dialect/RT/IR/RTTypes.h"
@@ -24,7 +37,7 @@ void RTDialect::initialize() {
::mlir::Type RTDialect::parseType(::mlir::DialectAsmParser &parser) const {
mlir::Type type;
if (parser.parseOptionalKeyword("future").succeeded()) {
generatedTypeParser(this->getContext(), parser, "future", type);
generatedTypeParser(parser, "future", type);
return type;
}
return type;
@@ -35,4 +48,4 @@ void RTDialect::printType(::mlir::Type type,
if (generatedTypePrinter(type, printer).failed()) {
printer.printType(type);
}
}
}


@@ -1,3 +1,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeUtilities.h"
@@ -6,3 +10,17 @@
#define GET_OP_CLASSES
#include "zamalang/Dialect/RT/IR/RTOps.cpp.inc"
using namespace mlir::zamalang::RT;
void DataflowTaskOp::build(
::mlir::OpBuilder &builder, ::mlir::OperationState &result,
::mlir::TypeRange resultTypes, ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
result.addOperands(operands);
result.addAttributes(attributes);
Region *reg = result.addRegion();
Block *body = new Block();
reg->push_back(body);
result.addTypes(resultTypes);
}


@@ -28,6 +28,7 @@ add_mlir_library(ZamalangSupport
LowLFHEUnparametrize
MLIRLowerableDialectsToLLVM
HLFHEDialectAnalysis
RTDialectAnalysis
MLIRExecutionEngine
${LLVM_PTHREAD_LIB}


@@ -84,6 +84,8 @@ void CompilerEngine::setVerifyDiagnostics(bool v) {
this->verifyDiagnostics = v;
}
void CompilerEngine::setAutoParallelize(bool v) { this->autoParallelize = v; }
void CompilerEngine::setGenerateClientParameters(bool v) {
this->generateClientParameters = v;
}
@@ -215,6 +217,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return errorDiag("Tiling of HLFHELinalg operations failed");
}
// Auto parallelization
if (this->autoParallelize &&
mlir::zamalang::pipeline::autopar(mlirContext, module, enablePass)
.failed()) {
return StreamStringError("Auto parallelization failed");
}
if (target == Target::HLFHE)
return std::move(res);


@@ -15,6 +15,7 @@
#include <zamalang/Conversion/Passes.h>
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Dialect/HLFHELinalg/Transforms/Tiling.h>
#include <zamalang/Dialect/RT/Analysis/Autopar.h>
#include <zamalang/Support/Pipeline.h>
#include <zamalang/Support/logging.h>
#include <zamalang/Support/math.h>
@@ -102,6 +103,17 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
return ret;
}
mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("AutoPar", pm, context);
addPotentiallyNestedPass(
pm, mlir::zamalang::createBuildDataflowTaskGraphPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
@@ -190,8 +202,18 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
enablePass);
addPotentiallyNestedPass(pm, mlir::createSCFBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::zamalang::createBufferizeDataflowTaskOpsPass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
enablePass);
// Lower Dataflow tasks to DFR
addPotentiallyNestedPass(pm, mlir::zamalang::createFixupDataflowTaskOpsPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::zamalang::createLowerDataflowTasksPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createLowerToCFGPass(), enablePass);
// Convert to MLIR LLVM Dialect


@@ -123,6 +123,11 @@ llvm::cl::opt<bool> splitInputFile(
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> autoParallelize(
"parallelize",
llvm::cl::desc("Generate (and execute if JIT) parallel code"),
llvm::cl::init(false));
llvm::cl::opt<std::string> jitFuncName(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
@@ -229,7 +234,7 @@ mlir::LogicalResult processInputBuffer(
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
llvm::Optional<llvm::ArrayRef<int64_t>> hlfhelinalgTileSizes,
llvm::raw_ostream &os,
bool autoParallelize, llvm::raw_ostream &os,
std::shared_ptr<mlir::zamalang::CompilerEngine::Library> outputLib) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
mlir::zamalang::CompilationContext::createShared();
@@ -237,6 +242,7 @@ mlir::LogicalResult processInputBuffer(
mlir::zamalang::JitCompilerEngine ce{ccx};
ce.setVerifyDiagnostics(verifyDiagnostics);
ce.setAutoParallelize(autoParallelize);
if (cmdline::passes.size() != 0) {
ce.setEnablePass([](mlir::Pass *pass) {
return std::any_of(
@@ -404,7 +410,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(inputBuffer), fileName, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, hlfhelinalgTileSizes, os, outputLib);
cmdline::verifyDiagnostics, hlfhelinalgTileSizes,
cmdline::autoParallelize, os, outputLib);
};
auto &os = output->os();
auto res = mlir::failure();