feat(compiler): move the lowering of dataflow tasks to RT dialect before bufferization.

Antoniu Pop
2022-08-19 10:52:49 +01:00
committed by Antoniu Pop
parent 26901a32da
commit 2cf80e76eb
20 changed files with 1453 additions and 1026 deletions

View File

@@ -0,0 +1,66 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/Operation.h>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
template <typename TypeConverterType>
struct FunctionConstantOpConversion
: public mlir::OpRewritePattern<mlir::func::ConstantOp> {
FunctionConstantOpConversion(mlir::MLIRContext *ctx,
TypeConverterType &converter,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::func::ConstantOp>(ctx, benefit),
converter(converter) {}
mlir::LogicalResult
matchAndRewrite(mlir::func::ConstantOp op,
mlir::PatternRewriter &rewriter) const override {
auto symTab = mlir::SymbolTable::getNearestSymbolTable(op);
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue());
assert(funcOp &&
"Function symbol missing in symbol table for function constant op.");
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
.getFunctionType()
.cast<mlir::FunctionType>();
typename TypeConverterType::SignatureConversion result(
funType.getNumInputs());
mlir::SmallVector<mlir::Type, 1> newResults;
if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) ||
failed(converter.convertTypes(funType.getResults(), newResults)))
return mlir::failure();
auto newType = mlir::FunctionType::get(
rewriter.getContext(), result.getConvertedTypes(), newResults);
rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newType); });
return mlir::success();
}
static bool isLegal(mlir::func::ConstantOp fun,
TypeConverterType &converter) {
auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun);
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue());
assert(funcOp &&
"Function symbol missing in symbol table for function constant op.");
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
.getFunctionType()
.cast<mlir::FunctionType>();
typename TypeConverterType::SignatureConversion result(
funType.getNumInputs());
mlir::SmallVector<mlir::Type, 1> newResults;
if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) ||
failed(converter.convertTypes(funType.getResults(), newResults)))
return false;
auto newType = mlir::FunctionType::get(
fun.getContext(), result.getConvertedTypes(), newResults);
return newType == fun.getType();
}
private:
TypeConverterType &converter;
};
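For illustration, a minimal usage sketch of this new header, assuming a hypothetical pass-local type converter named MyTypeConverter; it mirrors the registrations that the conversion passes below add in this commit.

// Sketch only: MyTypeConverter, converter, patterns and target stand in for
// the pass's own conversion state.
MyTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
mlir::ConversionTarget target(getContext());
// A func.constant op stays legal only once its function type no longer needs
// conversion (isLegal compares the converted type with the current one).
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
    [&](mlir::func::ConstantOp op) {
      return FunctionConstantOpConversion<MyTypeConverter>::isLegal(op, converter);
    });
patterns.add<FunctionConstantOpConversion<MyTypeConverter>>(&getContext(),
                                                            converter);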

View File

@@ -22,7 +22,8 @@ createBuildDataflowTaskGraphPass(bool debug = false);
std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createBufferizeDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass> createFinalizeTaskCreationPass(bool debug = false);
std::unique_ptr<mlir::Pass> createStartStopPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createFixupBufferDeallocationPass(bool debug = false);
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,

View File

@@ -53,23 +53,6 @@ def BufferizeDataflowTaskOps : Pass<"BufferizeDataflowTaskOps", "mlir::ModuleOp"
}];
}
def FixupDataflowTaskOps : Pass<"FixupDataflowTaskOps", "mlir::ModuleOp"> {
let summary =
"Fix DataflowTaskOp(s) before lowering.";
let description = [{
This pass fixes up code changes that intervene between the
BuildDataflowTaskGraph pass and the lowering of the taskgraph to
LLVMIR and calls to the DFR runtime system.
In particular, some operations (e.g., constants, dimension
operations, etc.) can be used within the task while only defined
outside. In most cases, cloning and sinking these operations into
the task is the simplest way to avoid adding dependences.
}];
}
def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
let summary =
"Outline the body of a DataflowTaskOp into a separate function which will serve as a task work function and lower the task graph to RT.";
@@ -82,6 +65,30 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
}];
}
def FinalizeTaskCreation : Pass<"FinalizeTaskCreation", "mlir::ModuleOp"> {
let summary =
"Finalize the CreateAsyncTaskOp ops.";
let description = [{
This pass adds the lower-level information missing from
CreateAsyncTaskOp, in particular the type sizes and, when
required, the runtime context.
}];
}
def StartStop : Pass<"StartStop", "mlir::ModuleOp"> {
let summary =
"Issue calls to start/stop the runtime system.";
let description = [{
This pass adds calls to _dfr_start and _dfr_stop, which
respectively initialize/start and pause the runtime system. When
required, the start function also distributes the evaluation keys
to the compute nodes, and the stop function clears the execution
context.
}];
}
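In runtime terms, this amounts to bracketing the lowered program with the C entry points declared in the runtime API (see the dfr header changes below). A hedged sketch, where the int64_t flag value and the runtime_context pointer are illustrative assumptions rather than what the pass emits verbatim:

// Conceptual shape of the calls inserted by StartStop (sketch, not the
// pass's literal output).
_dfr_start(/*use_dfr=*/1, /*ctx=*/runtime_context); // init/start; distributes evaluation keys when needed
// ... evaluate the dataflow task graph ...
_dfr_stop(/*use_dfr=*/1);                           // pause the runtime; clears the execution context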
def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> {
let summary =
"Prevent deallocation of buffers returned as futures by tasks.";

View File

@@ -24,6 +24,7 @@ void _dfr_set_use_omp(bool);
bool _dfr_is_jit();
bool _dfr_is_root_node();
bool _dfr_use_omp();
bool _dfr_is_distributed();
typedef enum _dfr_task_arg_type {
_DFR_TASK_ARG_BASE = 0,

View File

@@ -69,23 +69,27 @@ struct OpaqueInputData {
std::vector<size_t> _param_sizes,
std::vector<uint64_t> _param_types,
std::vector<size_t> _output_sizes,
std::vector<uint64_t> _output_types)
std::vector<uint64_t> _output_types, void *_context = nullptr)
: wfn_name(_wfn_name), params(std::move(_params)),
param_sizes(std::move(_param_sizes)),
param_types(std::move(_param_types)),
output_sizes(std::move(_output_sizes)),
output_types(std::move(_output_types)) {}
output_types(std::move(_output_types)), context(_context) {
if (_context)
params.push_back(_context);
}
OpaqueInputData(const OpaqueInputData &oid)
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
param_sizes(std::move(oid.param_sizes)),
param_types(std::move(oid.param_types)),
output_sizes(std::move(oid.output_sizes)),
output_types(std::move(oid.output_types)) {}
output_types(std::move(oid.output_types)), context(oid.context) {}
friend class hpx::serialization::access;
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar >> wfn_name;
bool has_context;
ar >> wfn_name >> has_context;
ar >> param_sizes >> param_types;
ar >> output_sizes >> output_types;
for (size_t p = 0; p < param_sizes.size(); ++p) {
@@ -114,27 +118,22 @@ struct OpaqueInputData {
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
// The copied pointer is meaningless - TODO: if the context
// can change dynamically (e.g., different evaluation keys)
// then this needs updating by passing key ids and retrieving
// adequate keys for the context.
delete ((char *)params[p]);
params[p] =
(void *)_dfr_node_level_runtime_context_manager->getContext();
} break;
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
"Error: invalid task argument type.");
}
}
if (has_context)
params.push_back(
(void *)_dfr_node_level_runtime_context_manager->getContext());
}
template <class Archive>
void save(Archive &ar, const unsigned int version) const {
ar << wfn_name;
bool has_context = (bool)(context != nullptr);
ar << wfn_name << has_context;
ar << param_sizes << param_types;
ar << output_sizes << output_types;
for (size_t p = 0; p < params.size(); ++p) {
for (size_t p = 0; p < param_sizes.size(); ++p) {
// Save the first level of the data structure - if the parameter
// is a tensor/memref, there is a second level.
ar << hpx::serialization::make_array((char *)params[p], param_sizes[p]);
@@ -152,10 +151,6 @@ struct OpaqueInputData {
ar << hpx::serialization::make_array(
mref.data + mref.offset * elementSize, size * elementSize);
} break;
case _DFR_TASK_ARG_CONTEXT: {
// Nothing to do now - TODO: pass key ids if these are not
// unique for a computation.
} break;
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
"Error: invalid task argument type.");
@@ -170,6 +165,7 @@ struct OpaqueInputData {
std::vector<uint64_t> param_types;
std::vector<size_t> output_sizes;
std::vector<uint64_t> output_types;
void *context;
};
struct OpaqueOutputData {
@@ -214,9 +210,6 @@ struct OpaqueOutputData {
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
nullptr;
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
} break;
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
@@ -243,9 +236,6 @@ struct OpaqueOutputData {
size *= mref.sizes[r];
ar << hpx::serialization::make_array(
mref.data + mref.offset * elementSize, size * elementSize);
} break;
case _DFR_TASK_ARG_CONTEXT: {
} break;
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
@@ -282,121 +272,121 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
wfn(output);
break;
case 1:
wfn(inputs.params[0], output);
wfn(output, inputs.params[0]);
break;
case 2:
wfn(inputs.params[0], inputs.params[1], output);
wfn(output, inputs.params[0], inputs.params[1]);
break;
case 3:
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output);
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2]);
break;
case 4:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], output);
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3]);
break;
case 5:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], output);
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4]);
break;
case 6:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5], output);
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5]);
break;
case 7:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], output);
inputs.params[6]);
break;
case 8:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], output);
inputs.params[6], inputs.params[7]);
break;
case 9:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8], output);
inputs.params[6], inputs.params[7], inputs.params[8]);
break;
case 10:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], output);
inputs.params[9]);
break;
case 11:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], output);
inputs.params[9], inputs.params[10]);
break;
case 12:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11], output);
inputs.params[9], inputs.params[10], inputs.params[11]);
break;
case 13:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], output);
inputs.params[12]);
break;
case 14:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], output);
inputs.params[12], inputs.params[13]);
break;
case 15:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14], output);
inputs.params[12], inputs.params[13], inputs.params[14]);
break;
case 16:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output);
inputs.params[15]);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output);
inputs.params[15], inputs.params[16]);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output);
inputs.params[15], inputs.params[16], inputs.params[17]);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output);
inputs.params[18]);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output);
inputs.params[18], inputs.params[19]);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
@@ -415,127 +405,127 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
wfn(output1, output2);
break;
case 1:
wfn(inputs.params[0], output1, output2);
wfn(output1, output2, inputs.params[0]);
break;
case 2:
wfn(inputs.params[0], inputs.params[1], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1]);
break;
case 3:
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1,
output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2]);
break;
case 4:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3]);
break;
case 5:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4]);
break;
case 6:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5], output1,
output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5]);
break;
case 7:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6]);
break;
case 8:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7]);
break;
case 9:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8], output1,
output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8]);
break;
case 10:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9]);
break;
case 11:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10]);
break;
case 12:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11], output1,
output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11]);
break;
case 13:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12]);
break;
case 14:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13]);
break;
case 15:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14], output1,
output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14]);
break;
case 16:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15]);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16]);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output1,
output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16],
inputs.params[17]);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16],
inputs.params[17], inputs.params[18]);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output1, output2);
wfn(output1, output2, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16],
inputs.params[17], inputs.params[18], inputs.params[19]);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
@@ -555,127 +545,127 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
wfn(output1, output2, output3);
break;
case 1:
wfn(inputs.params[0], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0]);
break;
case 2:
wfn(inputs.params[0], inputs.params[1], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1]);
break;
case 3:
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1,
output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2]);
break;
case 4:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3]);
break;
case 5:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4]);
break;
case 6:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5], output1,
output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5]);
break;
case 7:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6]);
break;
case 8:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7]);
break;
case 9:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8], output1,
output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8]);
break;
case 10:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9]);
break;
case 11:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10]);
break;
case 12:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11], output1,
output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11]);
break;
case 13:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12]);
break;
case 14:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13]);
break;
case 15:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14], output1,
output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14]);
break;
case 16:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15]);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16]);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output1,
output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16],
inputs.params[17]);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16],
inputs.params[17], inputs.params[18]);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output1, output2, output3);
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
inputs.params[2], inputs.params[3], inputs.params[4],
inputs.params[5], inputs.params[6], inputs.params[7],
inputs.params[8], inputs.params[9], inputs.params[10],
inputs.params[11], inputs.params[12], inputs.params[13],
inputs.params[14], inputs.params[15], inputs.params[16],
inputs.params[17], inputs.params[18], inputs.params[19]);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
@@ -693,12 +683,10 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
// Deallocate input data buffers from OID deserialization (load)
if (!_dfr_is_root_node()) {
for (size_t p = 0; p < inputs.param_sizes.size(); ++p) {
if (_dfr_get_arg_type(inputs.param_types[p]) != _DFR_TASK_ARG_CONTEXT) {
if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
delete (static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
->data);
delete ((char *)inputs.params[p]);
}
if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
delete (static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
->data);
delete ((char *)inputs.params[p]);
}
}

View File

@@ -180,7 +180,7 @@ struct RuntimeContextManager {
}
}
RuntimeContext **getContext() { return &context; }
RuntimeContext *getContext() { return context; }
void clearContext() {
if (context != nullptr)

View File

@@ -15,7 +15,7 @@ extern "C" {
typedef void (*wfnptr)(...);
void *_dfr_make_ready_future(void *, size_t);
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
void _dfr_create_async_task(wfnptr, void *, size_t, size_t, ...);
void _dfr_register_work_function(wfnptr);
void *_dfr_await_future(void *);
@@ -26,8 +26,7 @@ void _dfr_deallocate_future(void *);
void _dfr_deallocate_future_data(void *);
/* Initialisation & termination. */
void _dfr_start_c(int64_t, void *);
void _dfr_start(int64_t);
void _dfr_start(int64_t, void *);
void _dfr_stop(int64_t);
void _dfr_terminate();

View File

@@ -25,6 +25,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
@@ -108,6 +109,16 @@ public:
newShape, mlir::IntegerType::get(type.getContext(), 64));
return r;
});
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
.getElementType()));
});
addConversion([&](mlir::concretelang::RT::PointerType type) {
return mlir::concretelang::RT::PointerType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
.getElementType()));
});
}
};
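The two RT conversions above are purely structural: the future/pointer wrapper is preserved and only its element type is rewritten through the surrounding converter. A small sketch, assuming a converter instance named converter and an element type eltTy that this pass maps to loweredEltTy:

// Sketch: RT wrapper types survive the conversion, their element type does not.
mlir::Type fut = mlir::concretelang::RT::FutureType::get(eltTy);
mlir::Type convertedFut = converter.convertType(fut);
// convertedFut == RT::FutureType::get(loweredEltTy)
mlir::Type ptr = mlir::concretelang::RT::PointerType::get(eltTy);
mlir::Type convertedPtr = converter.convertType(ptr);
// convertedPtr == RT::PointerType::get(loweredEltTy)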
@@ -956,6 +967,14 @@ void ConcreteToBConcretePass::runOnOperation() {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
[&](mlir::func::ConstantOp op) {
return FunctionConstantOpConversion<
ConcreteToBConcreteTypeConverter>::isLegal(op, converter);
});
patterns
.insert<FunctionConstantOpConversion<ConcreteToBConcreteTypeConverter>>(
&getContext(), converter);
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
return converter.isLegal(forOp.getInitArgs().getTypes()) &&
@@ -969,19 +988,44 @@ void ConcreteToBConcretePass::runOnOperation() {
converter.isLegal(op->getOperandTypes());
});
// Conversion of RT Dialect Ops
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>,
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>>(&getContext(), converter);
// Conversion of RT Dialect Ops
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))

View File

@@ -14,11 +14,14 @@
#include "concretelang/Conversion/FHEToTFHE/Patterns.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/RT/IR/RTDialect.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
@@ -53,6 +56,16 @@ public:
eint.getContext(), eint));
return r;
});
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
.getElementType()));
});
addConversion([&](mlir::concretelang::RT::PointerType type) {
return mlir::concretelang::RT::PointerType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
.getElementType()));
});
}
};
@@ -269,6 +282,11 @@ struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
[&](mlir::func::ConstantOp op) {
return FunctionConstantOpConversion<FHEToTFHETypeConverter>::isLegal(
op, converter);
});
// Add all patterns required to lower all ops from `FHE` to
// `TFHE`
@@ -292,6 +310,8 @@ struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
patterns.add<SubEintOpPattern>(&getContext());
patterns.add<SubEintIntOpPattern>(&getContext());
patterns.add<FunctionConstantOpConversion<FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
FHEToTFHETypeConverter>>(
@@ -319,16 +339,43 @@ struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
patterns, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
converter);
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))

View File

@@ -7,6 +7,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
@@ -52,6 +53,16 @@ public:
this->glweInterPBSType(glwe));
return r;
});
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
.getElementType()));
});
addConversion([&](mlir::concretelang::RT::PointerType type) {
return mlir::concretelang::RT::PointerType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
.getElementType()));
});
}
TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) {
@@ -293,6 +304,14 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
[&](mlir::func::ConstantOp op) {
return FunctionConstantOpConversion<
TFHEGlobalParametrizationTypeConverter>::isLegal(op, converter);
});
patterns.add<
FunctionConstantOpConversion<TFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
@@ -354,16 +373,43 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
patterns, target, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
converter);
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))

View File

@@ -10,6 +10,7 @@
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/TFHEToConcrete/Patterns.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
@@ -50,6 +51,16 @@ public:
mlir::concretelang::convertTypeToLWE(glwe.getContext(), glwe));
return r;
});
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
.getElementType()));
});
addConversion([&](mlir::concretelang::RT::PointerType type) {
return mlir::concretelang::RT::PointerType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
.getElementType()));
});
}
};
@@ -174,10 +185,17 @@ void TFHEToConcretePass::runOnOperation() {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
[&](mlir::func::ConstantOp op) {
return FunctionConstantOpConversion<
TFHEToConcreteTypeConverter>::isLegal(op, converter);
});
// Add all patterns required to lower all ops from `TFHE` to
// `Concrete`
mlir::RewritePatternSet patterns(&getContext());
patterns.add<FunctionConstantOpConversion<TFHEToConcreteTypeConverter>>(
&getContext(), converter);
populateWithGeneratedTFHEToConcrete(patterns);
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
@@ -219,16 +237,43 @@ void TFHEToConcretePass::runOnOperation() {
patterns, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
converter);
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
target, converter);

View File

@@ -32,28 +32,13 @@ struct AddRuntimeContextToFuncOpPattern
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
newInputs, oldFuncType.getResults());
// Create the new func
mlir::func::FuncOp newFuncOp = rewriter.create<mlir::func::FuncOp>(
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
// Create the arguments of the new func
mlir::Region &newFuncBody = newFuncOp.getBody();
mlir::Block *newFuncEntryBlock = new mlir::Block();
llvm::SmallVector<mlir::Location> locations(newFuncTy.getInputs().size(),
oldFuncOp.getLoc());
rewriter.updateRootInPlace(oldFuncOp,
[&] { oldFuncOp.setType(newFuncTy); });
oldFuncOp.getBody().front().addArgument(
rewriter.getType<mlir::concretelang::Concrete::ContextType>(),
oldFuncOp.getLoc());
newFuncEntryBlock->addArguments(newFuncTy.getInputs(), locations);
newFuncBody.push_back(newFuncEntryBlock);
// Clone the old body to the new one
mlir::BlockAndValueMapping map;
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
}
for (auto &op : oldFuncOp.getBody().front()) {
newFuncEntryBlock->push_back(op.clone(map));
}
rewriter.eraseOp(oldFuncOp);
return mlir::success();
}
@@ -72,6 +57,50 @@ struct AddRuntimeContextToFuncOpPattern
}
};
namespace {
struct FunctionConstantOpConversion
: public mlir::OpRewritePattern<mlir::func::ConstantOp> {
FunctionConstantOpConversion(mlir::MLIRContext *ctx,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::func::ConstantOp>(ctx, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::func::ConstantOp op,
mlir::PatternRewriter &rewriter) const override {
auto symTab = mlir::SymbolTable::getNearestSymbolTable(op);
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue());
assert(funcOp &&
"Function symbol missing in symbol table for function constant op.");
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
.getFunctionType()
.cast<mlir::FunctionType>();
mlir::SmallVector<mlir::Type> newInputs(funType.getInputs().begin(),
funType.getInputs().end());
newInputs.push_back(
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
mlir::FunctionType newFuncTy =
rewriter.getType<mlir::FunctionType>(newInputs, funType.getResults());
rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newFuncTy); });
return mlir::success();
}
static bool isLegal(mlir::func::ConstantOp fun) {
auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun);
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue());
assert(funcOp &&
"Function symbol missing in symbol table for function constant op.");
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
.getFunctionType()
.cast<mlir::FunctionType>();
if ((AddRuntimeContextToFuncOpPattern::isLegal(
mlir::cast<mlir::func::FuncOp>(funcOp)) &&
fun.getType() == funType) ||
fun.getType() != funType)
return true;
return false;
}
};
} // namespace
struct AddRuntimeContextPass
: public AddRuntimeContextBase<AddRuntimeContextPass> {
void runOnOperation() final;
@@ -90,8 +119,13 @@ void AddRuntimeContextPass::runOnOperation() {
[&](mlir::func::FuncOp funcOp) {
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
});
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
[&](mlir::func::ConstantOp op) {
return FunctionConstantOpConversion::isLegal(op);
});
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
patterns.add<FunctionConstantOpConversion>(patterns.getContext());
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))

View File

@@ -153,6 +153,9 @@ mlir::Value getContextArgument(mlir::Operation *op) {
mlir::Block *block = op->getBlock();
while (block != nullptr) {
if (llvm::isa<mlir::func::FuncOp>(block->getParentOp())) {
block = &mlir::cast<mlir::func::FuncOp>(block->getParentOp())
.getBody()
.front();
auto context =
std::find_if(block->getArguments().rbegin(),
@@ -160,7 +163,6 @@ mlir::Value getContextArgument(mlir::Operation *op) {
return arg.getType()
.isa<mlir::concretelang::Concrete::ContextType>();
});
assert(context != block->getArguments().rend() &&
"Cannot find the Concrete.context");

View File

@@ -10,11 +10,16 @@
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/IR/RTTypes.h>
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <concretelang/Conversion/Utils/FuncConstOpConversion.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Bufferization/Transforms/Bufferize.h>
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
@@ -27,68 +32,35 @@ namespace mlir {
namespace concretelang {
namespace {
class BufferizeDataflowYieldOp
: public OpConversionPattern<RT::DataflowYieldOp> {
class BufferizeRTTypesConverter
: public mlir::bufferization::BufferizeTypeConverter {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RT::DataflowYieldOp op, RT::DataflowYieldOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<RT::DataflowYieldOp>(op, mlir::TypeRange(),
adaptor.getOperands());
return success();
BufferizeRTTypesConverter() {
addConversion([&](mlir::concretelang::RT::FutureType type) {
return mlir::concretelang::RT::FutureType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
.getElementType()));
});
addConversion([&](mlir::concretelang::RT::PointerType type) {
return mlir::concretelang::RT::PointerType::get(
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
.getElementType()));
});
addConversion([&](mlir::FunctionType type) {
SignatureConversion result(type.getNumInputs());
mlir::SmallVector<mlir::Type, 1> newResults;
if (failed(this->convertSignatureArgs(type.getInputs(), result)) ||
failed(this->convertTypes(type.getResults(), newResults)))
return type;
return mlir::FunctionType::get(type.getContext(),
result.getConvertedTypes(), newResults);
});
}
};
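// Note (illustrative, not part of the pass logic): this converter extends the
// stock tensor-to-memref bufferization converter so that element types nested
// inside RT futures, RT pointers and function types are rewritten as well.
// A minimal sketch of what it does, assuming an MLIRContext `ctx`:
//   BufferizeRTTypesConverter conv;
//   auto fut = RT::FutureType::get(
//       RankedTensorType::get({8}, IntegerType::get(ctx, 64)));
//   // conv.convertType(fut) yields a future wrapping memref<8xi64>.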
} // namespace
namespace {
class BufferizeDataflowTaskOp : public OpConversionPattern<RT::DataflowTaskOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RT::DataflowTaskOp op, RT::DataflowTaskOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Type> newResults;
(void)getTypeConverter()->convertTypes(op.getResultTypes(), newResults);
auto newop = rewriter.create<RT::DataflowTaskOp>(op.getLoc(), newResults,
adaptor.getOperands());
// We cannot clone here as cloned ops must be legalized (so this
// would break on the YieldOp). Instead use mergeBlocks which
// moves the ops instead of cloning.
rewriter.mergeBlocks(op.getBody(), newop.getBody(),
newop.getBody()->getArguments());
// Because of previous bufferization there are buffer cast ops
// that have been generated for the previously tensor results of
// some tasks. These cannot just be replaced directly as the
// task's results would still be live.
for (auto res : llvm::enumerate(op.getResults())) {
// If this result is getting bufferized ...
if (res.value().getType() !=
getTypeConverter()->convertType(res.value().getType())) {
for (auto &use : llvm::make_early_inc_range(res.value().getUses())) {
// ... and its uses are in `ToMemrefOp`s, then we
// replace further uses of the buffer cast.
if (isa<mlir::bufferization::ToMemrefOp>(use.getOwner())) {
rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())});
}
}
}
}
rewriter.replaceOp(op, {newop.getResults()});
return success();
}
};
} // namespace
void populateRTBufferizePatterns(
mlir::bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeDataflowYieldOp, BufferizeDataflowTaskOp>(
typeConverter, patterns.getContext());
}
namespace {
/// For documentation see Autopar.td
struct BufferizeDataflowTaskOpsPass
@@ -97,15 +69,75 @@ struct BufferizeDataflowTaskOpsPass
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();
mlir::bufferization::BufferizeTypeConverter typeConverter;
BufferizeRTTypesConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateRTBufferizePatterns(typeConverter, patterns);
populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, typeConverter);
patterns.add<FunctionConstantOpConversion<BufferizeRTTypesConverter>>(
context, typeConverter);
// Forbid all RT ops that still use/return tensors
target.addDynamicallyLegalDialect<RT::RTDialect>(
[&](Operation *op) { return typeConverter.isLegal(op); });
target.addDynamicallyLegalDialect<mlir::func::FuncDialect>([&](Operation
*op) {
if (auto fun = dyn_cast_or_null<mlir::func::FuncOp>(op))
return typeConverter.isSignatureLegal(fun.getFunctionType()) &&
typeConverter.isLegal(&fun.getBody());
if (auto fun = dyn_cast_or_null<mlir::func::ConstantOp>(op))
return FunctionConstantOpConversion<BufferizeRTTypesConverter>::isLegal(
fun, typeConverter);
return typeConverter.isLegal(op);
});
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::MakeReadyFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::AwaitFutureOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::CreateAsyncTaskOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::WorkFunctionReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
typeConverter);
// Conversion of RT Dialect Ops
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::MakeReadyFutureOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::AwaitFutureOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::CreateAsyncTaskOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target,
typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target,
typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::WorkFunctionReturnOp>(target, typeConverter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target,
typeConverter);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();

View File

@@ -18,7 +18,6 @@
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
@@ -61,9 +60,8 @@ static bool isAggregatingBeneficiary(Operation *op) {
return isa<FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintIntOp,
FHE::AddEintOp, FHE::SubIntEintOp, FHE::SubEintIntOp,
FHE::MulEintIntOp, FHE::SubEintOp, FHE::NegEintOp,
FHELinalg::FromElementOp, arith::ConstantOp, memref::DimOp,
arith::SelectOp, mlir::arith::CmpIOp, memref::GetGlobalOp,
memref::CastOp>(op);
FHELinalg::FromElementOp, arith::ConstantOp, arith::SelectOp,
mlir::arith::CmpIOp>(op);
}
static bool
@@ -95,87 +93,6 @@ aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
return true;
}
static bool isFunctionCallName(OpOperand *use, StringRef name) {
func::CallOp call = dyn_cast_or_null<mlir::func::CallOp>(use->getOwner());
if (!call)
return false;
SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym)
return false;
func::FuncOp called = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupNearestSymbolFrom(call, sym));
if (!called)
return false;
return called.getName() == name;
}
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
for (auto &use : val.getUses()) {
aliasedUses.insert(&use);
if (dyn_cast<ViewLikeOpInterface>(use.getOwner()))
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
}
}
static bool aggregateOutputMemrefAllocations(
Operation *op, SetVector<Operation *> &beneficiaryOps,
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
if (beneficiaryOps.count(op))
return true;
if (!isa<mlir::memref::AllocOp>(op))
return false;
Value val = op->getResults().front();
DenseSet<OpOperand *> aliasedUses;
getAliasedUses(val, aliasedUses);
// Helper function checking if a memref use writes to memory
auto hasMemoryWriteEffect = [&](OpOperand *use) {
// Call ops targeting concrete-ffi do not have memory effects
// interface, so handle apart.
// TODO: this could be handled better in BConcrete or higher.
if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") ||
isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") ||
isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") ||
isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") ||
isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") ||
isFunctionCallName(use, "memref_keyswitch_lwe_u64") ||
isFunctionCallName(use, "memref_bootstrap_lwe_u64"))
if (use->getOwner()->getOperand(0) == use->get())
return true;
if (isFunctionCallName(use, "memref_copy_one_rank"))
if (use->getOwner()->getOperand(1) == use->get())
return true;
// Otherwise we rely on the memory effect interface
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
if (!effectInterface)
return false;
SmallVector<MemoryEffects::EffectInstance, 2> effects;
effectInterface.getEffects(effects);
for (auto eff : effects)
if (isa<MemoryEffects::Write>(eff.getEffect()) &&
eff.getValue() == use->get())
return true;
return false;
};
// We need to check if this allocated memref is written in this task.
// TODO: for now we'll assume that we don't do partial writes or read/writes.
for (auto use : aliasedUses)
if (hasMemoryWriteEffect(use) &&
use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
// We will sink the operation, mark its results as now available.
beneficiaryOps.insert(op);
for (Value result : op->getResults())
availableValues.insert(result);
return true;
}
return false;
}
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
Region &taskOpBody = taskOp.body();
@@ -191,8 +108,6 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
if (!operandOp)
continue;
aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues);
aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues,
taskOp);
}
// Insert operations so that the defs get cloned before uses.
@@ -283,45 +198,5 @@ std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
}
namespace {
/// For documentation see Autopar.td
struct FixupDataflowTaskOpsPass
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {
void runOnOperation() override {
auto module = getOperation();
module->walk([](RT::DataflowTaskOp op) {
assert(!failed(coarsenDFTask(op)) &&
"Failing to sink operations into DFT");
});
// Finally clear up any remaining alloc/dealloc ops that are
// meaningless
SetVector<Operation *> eraseOps;
module->walk([&](memref::AllocOp op) {
// If this memref.alloc's only use left is the
// dealloc, erase both.
if (op->hasOneUse() &&
isa<mlir::memref::DeallocOp>(op->use_begin()->getOwner())) {
eraseOps.insert(op->use_begin()->getOwner());
eraseOps.insert(op);
}
});
for (auto op : eraseOps)
op->erase();
}
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug) {
return std::make_unique<FixupDataflowTaskOpsPass>(debug);
}
} // end namespace concretelang
} // end namespace mlir

View File

@@ -6,6 +6,7 @@
#include <iostream>
#include <concretelang/Conversion/Tools.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
#include <concretelang/Dialect/Concrete/IR/ConcreteTypes.h>
@@ -18,9 +19,7 @@
#include <concretelang/Dialect/RT/IR/RTTypes.h>
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/math.h>
#include <mlir/IR/BuiltinOps.h>
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
#include <llvm/IR/Instructions.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
@@ -39,7 +38,9 @@
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Interfaces/ViewLikeInterface.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>
@@ -67,10 +68,10 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
// types, which will be changed to use an indirection when lowering.
SmallVector<Type, 4> operandTypes;
operandTypes.reserve(DFTOp.getNumOperands() + DFTOp.getNumResults());
for (Value operand : DFTOp.getOperands())
operandTypes.push_back(RT::PointerType::get(operand.getType()));
for (Value res : DFTOp.getResults())
operandTypes.push_back(RT::PointerType::get(res.getType()));
for (Value operand : DFTOp.getOperands())
operandTypes.push_back(RT::PointerType::get(operand.getType()));
FunctionType type = FunctionType::get(DFTOp.getContext(), operandTypes, {});
auto outlinedFunc = builder.create<func::FuncOp>(loc, workFunctionName, type);
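// Resulting signature (a sketch): the outlined work function takes RT pointer
// placeholders for all task results first, followed by RT pointer
// placeholders for all task operands, and returns nothing; values flow
// through these placeholders at runtime.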
@@ -82,6 +83,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
outlinedFuncBody.push_back(outlinedEntryBlock);
BlockAndValueMapping map;
int input_offset = DFTOp.getNumResults();
Block &entryBlock = outlinedFuncBody.front();
builder.setInsertionPointToStart(&entryBlock);
for (auto operand : llvm::enumerate(DFTOp.getOperands())) {
@@ -89,7 +91,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
auto derefdop =
builder.create<RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
DFTOp.getLoc(), operand.value().getType(),
entryBlock.getArgument(operand.index()));
entryBlock.getArgument(operand.index() + input_offset));
map.map(operand.value(), derefdop->getResult(0));
}
DFTOpBody.cloneInto(&outlinedFuncBody, map);
@@ -99,17 +101,14 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
builder.setInsertionPointToEnd(&entryBlock);
builder.create<cf::BranchOp>(loc, clonedDFTOpEntry);
// TODO: we use a WorkFunctionReturnOp to tie return to the
// corresponding argument. This can be lowered to a copy/deref for
// shared memory and pointers, but needs to be handled for
// distributed memory.
// WorkFunctionReturnOp ties return to the corresponding argument.
// This is lowered to a copy/deref for shared memory and pointers,
// and handled in the serializer for distributed memory.
outlinedFunc.walk([&](RT::DataflowYieldOp op) {
OpBuilder replacer(op);
int output_offset = DFTOp.getNumOperands();
for (auto ret : llvm::enumerate(op.getOperands()))
replacer.create<RT::WorkFunctionReturnOp>(
op.getLoc(), ret.value(),
outlinedFunc.getArgument(ret.index() + output_offset));
op.getLoc(), ret.value(), outlinedFunc.getArgument(ret.index()));
replacer.create<func::ReturnOp>(op.getLoc());
op.erase();
});
@@ -126,33 +125,37 @@ static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement,
}
}
static mlir::Type stripType(mlir::Type type) {
if (type.isa<RT::FutureType>())
return stripType(type.dyn_cast<RT::FutureType>().getElementType());
if (type.isa<RT::PointerType>())
return stripType(type.dyn_cast<RT::PointerType>().getElementType());
return type;
}
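// Illustrative: stripType recursively unwraps RT::FutureType and
// RT::PointerType, so e.g. a future of a pointer of a memref yields the
// underlying MemRefType; any other type is returned unchanged.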
// TODO: Fix type sizes. For now we're using some default values.
static std::pair<mlir::Value, mlir::Value>
static std::pair<Value, Value>
getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
DataLayout dataLayout = DataLayout::closest(val.getDefiningOp());
Type type = (val.getType().isa<RT::FutureType>())
? val.getType().dyn_cast<RT::FutureType>().getElementType()
: val.getType();
Type type = stripType(val.getType());
// In the case of memref, we need to determine how much space
// (conservatively) we need to store the memref itself. Overshooting
// by a few bytes should not be an issue, so the main thing is to
// properly account for the rank.
if (type.isa<mlir::MemRefType>()) {
// Space for the allocated and aligned pointers, and offset
Value ptrs_offset =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(24));
// For the sizes and shapes arrays, we need 2*8 = 16 times the rank in bytes
Value multiplier =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
unsigned _rank = type.dyn_cast<mlir::MemRefType>().getRank();
Value rank = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_rank));
Value sizes_shapes = builder.create<LLVM::MulOp>(loc, rank, multiplier);
Value typeSize =
builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
// Space for the allocated and aligned pointers, and offset plus
// rank * sizes and strides
size_t element_size;
unsigned rank = type.dyn_cast<mlir::MemRefType>().getRank();
Type elementType = type.dyn_cast<mlir::MemRefType>().getElementType();
element_size = dataLayout.getTypeSize(elementType);
size_t size = 24 + 16 * rank;
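// Worked example (assuming 64-bit pointers and indices): a rank-2 memref
// descriptor needs the allocated and aligned pointers plus the offset
// (3 * 8 = 24 bytes) and the sizes and strides arrays (2 * 2 * 8 = 32 bytes),
// i.e. 24 + 16 * 2 = 56 bytes.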
Value typeSize =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(size));
// Assume here that the base type is a simple scalar-type or at
// least its size can be determined.
// size_t elementAttr = dataLayout.getTypeSize(elementType);
@@ -160,7 +163,6 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
// elementAttr <<= 8;
// elementAttr |= _DFR_TASK_ARG_MEMREF;
uint64_t elementAttr = 0;
size_t element_size = dataLayout.getTypeSize(elementType);
elementAttr =
dfr::_dfr_set_arg_type(elementAttr, dfr::_DFR_TASK_ARG_MEMREF);
elementAttr = dfr::_dfr_set_memref_element_size(elementAttr, element_size);
@@ -169,34 +171,19 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
}
// Unranked memrefs should be lowered to just pointer + size, so we need 16
// bytes.
assert(!type.isa<mlir::UnrankedMemRefType>() &&
"UnrankedMemRefType not currently supported");
if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT));
Value typeSize =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
}
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_BASE));
// FHE types are converted to pointers, so we take their size as 8
// bytes until we can get the actual size of the actual types.
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
} else if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT));
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
}
// For all other types, get type size.
Value result = builder.create<arith::ConstantOp>(
Value typeSize = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type)));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
}
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
@@ -212,7 +199,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
if (!val.getType().isa<RT::FutureType>()) {
OpBuilder::InsertionGuard guard(builder);
Type futType = RT::FutureType::get(val.getType());
Value memrefCloned;
// Find out if this value is needed in any other task
SmallVector<Operation *, 2> taskOps;
@@ -224,20 +210,10 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
first = op;
builder.setInsertionPoint(first);
// If we are building a future on a MemRef, then we need to clone
// the memref in order to allow the deallocation pass which does
// not synchronize with task execution.
if (val.getType().isa<mlir::MemRefType>()) {
memrefCloned = builder.create<arith::ConstantOp>(
val.getLoc(), builder.getI64IntegerAttr(1));
} else {
memrefCloned = builder.create<arith::ConstantOp>(
val.getLoc(), builder.getI64IntegerAttr(0));
}
auto mrf = builder.create<RT::MakeReadyFutureOp>(val.getLoc(), futType,
val, memrefCloned);
auto mrf = builder.create<RT::MakeReadyFutureOp>(
val.getLoc(), futType, val,
builder.create<arith::ConstantOp>(val.getLoc(),
builder.getI64IntegerAttr(0)));
replaceAllUsesInDFTsInRegionWith(val, mrf, opBody);
}
}
@@ -246,7 +222,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
// DataflowTaskOp. This also includes the necessary handling of
// operands and results (conversion to/from futures and propagation).
SmallVector<Value, 4> catOperands;
int size = 3 + DFTOp.getNumResults() * 3 + DFTOp.getNumOperands() * 3;
int size = 3 + DFTOp.getNumResults() + DFTOp.getNumOperands();
catOperands.reserve(size);
auto fnptr = builder.create<mlir::func::ConstantOp>(
DFTOp.getLoc(), workFunction.getFunctionType(),
@@ -258,12 +234,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
catOperands.push_back(fnptr.getResult());
catOperands.push_back(numIns.getResult());
catOperands.push_back(numOuts.getResult());
for (auto operand : DFTOp.getOperands()) {
auto op_size = getTaskArgumentSizeAndType(operand, DFTOp.getLoc(), builder);
catOperands.push_back(operand);
catOperands.push_back(op_size.first);
catOperands.push_back(op_size.second);
}
// We need to adjust the results for the CreateAsyncTaskOp which
// are the work function's returns through pointers passed as
@@ -276,11 +246,11 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
futType);
auto op_size = getTaskArgumentSizeAndType(result, DFTOp.getLoc(), builder);
map.map(result, brpp->getResult(0));
catOperands.push_back(brpp->getResult(0));
catOperands.push_back(op_size.first);
catOperands.push_back(op_size.second);
}
for (auto operand : DFTOp.getOperands()) {
catOperands.push_back(operand);
}
builder.create<RT::CreateAsyncTaskOp>(
DFTOp.getLoc(),
@@ -341,59 +311,6 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
static void propagateMemRefLayoutInDFTs(RT::DataflowTaskOp op, Value val,
Value newval) {
for (auto &use : llvm::make_early_inc_range(val.getUses()))
if (use.getOwner()->getParentOfType<RT::DataflowTaskOp>() != nullptr) {
OpBuilder builder(use.getOwner());
Value cast_newval = builder.create<mlir::memref::CastOp>(
val.getLoc(), val.getType(), newval);
use.set(cast_newval);
}
}
static void cloneMemRefTaskArgumentsWithIdentityMaps(RT::DataflowTaskOp op) {
OpBuilder builder(op);
for (Value val : op.getOperands()) {
if (val.getType().isa<mlir::MemRefType>()) {
OpBuilder::InsertionGuard guard(builder);
// Find out if this memref is needed in any other task to clone
// before all uses
SmallVector<Operation *, 2> taskOps;
for (auto &use : val.getUses())
if (isa<RT::DataflowTaskOp>(use.getOwner()))
taskOps.push_back(use.getOwner());
Operation *first = op;
for (auto op : taskOps)
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
first = op;
builder.setInsertionPoint(first);
// Get the type of memref that we will clone. In case this is
// a subview, we discard the mapping so we copy to a contiguous
// layout which pre-serializes this.
MemRefType mrType_base = val.getType().dyn_cast<mlir::MemRefType>();
MemRefType mrType = mrType_base;
if (!mrType_base.getLayout().isIdentity()) {
unsigned rank = mrType_base.getRank();
mrType = MemRefType::Builder(mrType_base)
.setShape(mrType_base.getShape())
.setLayout(AffineMapAttr::get(
builder.getMultiDimIdentityMap(rank)));
}
Value newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
.getResult();
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
// Value cast_newval = builder.create<mlir::memref::CastOp>(val.getLoc(),
// mrType_base, newval);
replaceAllUsesInDFTsInRegionWith(
val, newval, op->getParentOfType<func::FuncOp>().getBody());
propagateMemRefLayoutInDFTs(op, val, newval);
}
}
}
/// For documentation see Autopar.td
struct LowerDataflowTasksPass
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
@@ -413,8 +330,8 @@ struct LowerDataflowTasksPass
SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func);
SmallVector<std::pair<RT::DataflowTaskOp, func::FuncOp>, 4> outliningMap;
// Outline DataflowTaskOp bodies to work functions
func.walk([&](RT::DataflowTaskOp op) {
cloneMemRefTaskArgumentsWithIdentityMaps(op);
auto workFunctionName =
Twine("_dfr_DFT_work_function__") +
Twine(op->getParentOfType<func::FuncOp>().getName()) +
@@ -423,15 +340,15 @@ struct LowerDataflowTasksPass
outlineWorkFunction(op, workFunctionName.str());
outliningMap.push_back(
std::pair<RT::DataflowTaskOp, func::FuncOp>(op, outlinedFunc));
workFunctions.push_back(outlinedFunc);
symbolTable.insert(outlinedFunc);
workFunctions.push_back(outlinedFunc);
return WalkResult::advance();
});
// Lower the DF task ops to RT dialect ops.
for (auto mapping : outliningMap)
lowerDataflowTaskOp(mapping.first, mapping.second);
// Gather all entry points (assuming no recursive calls to entry points)
// Main is always an entry-point - otherwise check if this
// function is called within the module. TODO: we assume no
// recursion.
@@ -449,17 +366,6 @@ struct LowerDataflowTasksPass
});
for (auto entryPoint : entryPoints) {
// Check if this entry point uses a context - do this before we
// remove arguments in remote nodes
int ctxIndex = -1;
for (auto arg : llvm::enumerate(entryPoint.getArguments()))
if (arg.value()
.getType()
.isa<mlir::concretelang::Concrete::ContextType>()) {
ctxIndex = arg.index();
break;
}
// If this is a JIT invocation and we're not on the root node,
// we do not need to do any computation, only register all work
// functions with the runtime system
@@ -486,49 +392,6 @@ struct LowerDataflowTasksPass
// runtime.
for (auto wf : workFunctions)
registerWorkFunction(entryPoint, wf);
// Issue _dfr_start/stop calls for this function
OpBuilder builder(entryPoint.getBody());
builder.setInsertionPointToStart(&entryPoint.getBody().front());
int useDFR = (workFunctions.empty()) ? 0 : 1;
Value useDFRVal = builder.create<arith::ConstantOp>(
entryPoint.getLoc(), builder.getI64IntegerAttr(useDFR));
if (ctxIndex >= 0) {
auto startFunTy =
(dfr::_dfr_is_root_node())
? mlir::FunctionType::get(
entryPoint->getContext(),
{useDFRVal.getType(),
entryPoint.getArgument(ctxIndex).getType()},
{})
: mlir::FunctionType::get(entryPoint->getContext(),
{useDFRVal.getType()}, {});
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start_c",
startFunTy);
(dfr::_dfr_is_root_node())
? builder.create<mlir::func::CallOp>(
entryPoint.getLoc(), "_dfr_start_c", mlir::TypeRange(),
mlir::ValueRange(
{useDFRVal, entryPoint.getArgument(ctxIndex)}))
: builder.create<mlir::func::CallOp>(entryPoint.getLoc(),
"_dfr_start_c",
mlir::TypeRange(), useDFRVal);
} else {
auto startFunTy = mlir::FunctionType::get(entryPoint->getContext(),
{useDFRVal.getType()}, {});
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start",
startFunTy);
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_start",
mlir::TypeRange(), useDFRVal);
}
builder.setInsertionPoint(entryPoint.getBody().back().getTerminator());
auto stopFunTy = mlir::FunctionType::get(entryPoint->getContext(),
{useDFRVal.getType()}, {});
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_stop",
stopFunTy);
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_stop",
mlir::TypeRange(), useDFRVal);
}
}
LowerDataflowTasksPass(bool debug) : debug(debug){};
@@ -544,6 +407,188 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug) {
namespace {
// For documentation see Autopar.td
struct StartStopPass : public StartStopBase<StartStopPass> {
void runOnOperation() override {
auto module = getOperation();
int useDFR = 0;
SmallVector<func::FuncOp, 1> entryPoints;
module.walk([&](mlir::func::FuncOp func) {
// Do not add start/stop to work functions - but if any are
// present, then we need to activate the runtime
if (func->getAttr("_dfr_work_function_attribute")) {
useDFR = 1;
} else {
// Main is always an entry-point - otherwise check if this
// function is called within the module. TODO: we assume no
// recursion.
if (func.getName() == "main")
entryPoints.push_back(func);
else {
bool found = false;
module.walk([&](mlir::func::CallOp op) {
if (getCalledFunction(op) == func)
found = true;
});
if (!found)
entryPoints.push_back(func);
}
}
});
for (auto entryPoint : entryPoints) {
// Issue _dfr_start/stop calls for this function
OpBuilder builder(entryPoint.getBody());
builder.setInsertionPointToStart(&entryPoint.getBody().front());
Value useDFRVal = builder.create<arith::ConstantOp>(
entryPoint.getLoc(), builder.getI64IntegerAttr(useDFR));
// Check if this entry point uses a context
Value ctx = nullptr;
if (dfr::_dfr_is_root_node())
for (auto arg : llvm::enumerate(entryPoint.getArguments()))
if (arg.value()
.getType()
.isa<mlir::concretelang::Concrete::ContextType>()) {
ctx = arg.value();
break;
}
if (!ctx)
ctx = builder.create<arith::ConstantOp>(entryPoint.getLoc(),
builder.getI64IntegerAttr(0));
auto startFunTy = mlir::FunctionType::get(
entryPoint->getContext(), {useDFRVal.getType(), ctx.getType()}, {});
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start",
startFunTy);
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_start",
mlir::TypeRange(),
mlir::ValueRange({useDFRVal, ctx}));
builder.setInsertionPoint(entryPoint.getBody().back().getTerminator());
auto stopFunTy = mlir::FunctionType::get(entryPoint->getContext(),
{useDFRVal.getType()}, {});
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_stop",
stopFunTy);
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_stop",
mlir::TypeRange(), useDFRVal);
}
}
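// For reference, the forward declarations inserted above amount to (a
// sketch; the authoritative signatures live in the DFR runtime library):
//   _dfr_start : (i64, <context value or i64 placeholder>) -> ()
//   _dfr_stop  : (i64) -> ()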
StartStopPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // namespace
std::unique_ptr<mlir::Pass> createStartStopPass(bool debug) {
return std::make_unique<StartStopPass>(debug);
}
namespace {
// For documentation see Autopar.td
struct FinalizeTaskCreationPass
: public FinalizeTaskCreationBase<FinalizeTaskCreationPass> {
void runOnOperation() override {
auto module = getOperation();
std::vector<Operation *> ops;
module.walk([&](RT::CreateAsyncTaskOp catOp) {
OpBuilder builder(catOp);
SmallVector<Value, 4> operands;
// Determine if this task needs a runtime context
Value ctx = nullptr;
SymbolRefAttr sym =
catOp->getAttr("workfn").dyn_cast_or_null<SymbolRefAttr>();
assert(sym && "Work function symbol attribute missing.");
func::FuncOp workfn = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupNearestSymbolFrom(catOp, sym));
assert(workfn && "Task work function missing.");
if (workfn.getNumArguments() > catOp.getNumOperands() - 3)
ctx = *catOp->getParentOfType<func::FuncOp>().getArguments().rbegin();
else
ctx = builder.create<arith::ConstantOp>(catOp.getLoc(),
builder.getI64IntegerAttr(0));
int index = 0;
for (auto op : catOp.getOperands()) {
operands.push_back(op);
// Add the runtime context in second position - done in all cases to avoid
// checking whether it is needed; it can be a null placeholder when not
// relevant.
if (index == 0)
operands.push_back(ctx);
// First three operands are the function pointer, number inputs
// and number outputs - nothing further to do.
if (++index <= 3)
continue;
auto op_size = getTaskArgumentSizeAndType(op, catOp.getLoc(), builder);
operands.push_back(op_size.first);
operands.push_back(op_size.second);
}
builder.create<RT::CreateAsyncTaskOp>(catOp.getLoc(), sym, operands);
ops.push_back(catOp);
});
for (auto op : ops) {
op->erase();
}
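// At this point the operand list of each rebuilt RT.CreateAsyncTaskOp follows
// the layout consumed by the runtime (a sketch, reconstructed from the loop
// above): work-function pointer, runtime context, number of inputs, number of
// outputs, then for every remaining original operand a triple of
// (value, byte size, type tag) - outputs first, then inputs.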
// If we are building a future on a MemRef, we need to clone it into a
// contiguous (identity-layout) buffer.
// TODO: the performance of shared memory can be improved by
// allowing view-like access instead of cloning, but memory
// deallocation needs to be synchronized appropriately
module.walk([&](RT::MakeReadyFutureOp op) {
OpBuilder builder(op);
Value val = op.getOperand(0);
Value clone = op.getOperand(1);
if (val.getType().isa<mlir::MemRefType>()) {
MemRefType mrType_base = val.getType().dyn_cast<mlir::MemRefType>();
MemRefType mrType = mrType_base;
if (!mrType_base.getLayout().isIdentity()) {
unsigned rank = mrType_base.getRank();
mrType = MemRefType::Builder(mrType_base)
.setShape(mrType_base.getShape())
.setLayout(AffineMapAttr::get(
builder.getMultiDimIdentityMap(rank)));
}
// We need to make a copy of this MemRef to allow deallocation
// based on refcounting
Value newval =
builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
.getResult();
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
clone = builder.create<arith::ConstantOp>(op.getLoc(),
builder.getI64IntegerAttr(1));
op->setOperand(0, newval);
op->setOperand(1, clone);
}
});
}
FinalizeTaskCreationPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // namespace
std::unique_ptr<mlir::Pass> createFinalizeTaskCreationPass(bool debug) {
return std::make_unique<FinalizeTaskCreationPass>(debug);
}
namespace {
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
for (auto &use : val.getUses()) {
aliasedUses.insert(&use);
if (dyn_cast<ViewLikeOpInterface>(use.getOwner()))
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
}
}
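// Illustrative: because view-like ops are followed, a memref.dealloc of a
// buffer whose memref.subview feeds an RT.MakeReadyFutureOp or
// RT.WorkFunctionReturnOp is also erased by the pass below.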
// For documentation see Autopar.td
struct FixupBufferDeallocationPass
: public FixupBufferDeallocationBase<FixupBufferDeallocationPass> {
@@ -552,23 +597,21 @@ struct FixupBufferDeallocationPass
auto module = getOperation();
std::vector<Operation *> ops;
// All buffers allocated and either made into a future, directly
// or as a result of being returned by a task, are managed by the
// DFR runtime system's reference counting.
module.walk([&](RT::WorkFunctionReturnOp retOp) {
for (auto &use :
llvm::make_early_inc_range(retOp.getOperands().front().getUses()))
if (isa<mlir::memref::DeallocOp>(use.getOwner()))
ops.push_back(use.getOwner());
module.walk([&](mlir::memref::DeallocOp op) {
Value alloc = op.getOperand();
DenseSet<OpOperand *> aliasedUses;
getAliasedUses(alloc, aliasedUses);
for (auto use : aliasedUses)
if (isa<RT::WorkFunctionReturnOp, RT::MakeReadyFutureOp>(
use->getOwner())) {
ops.push_back(op);
return;
}
});
module.walk([&](RT::MakeReadyFutureOp mrfOp) {
for (auto &use :
llvm::make_early_inc_range(mrfOp.getOperands().front().getUses()))
if (isa<mlir::memref::DeallocOp>(use.getOwner()))
ops.push_back(use.getOwner());
});
for (auto op : ops)
for (auto op : ops) {
op->erase();
}
}
FixupBufferDeallocationPass(bool debug) : debug(debug){};

View File

@@ -9,6 +9,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Transforms/RegionUtils.h>
@@ -21,9 +22,10 @@ using namespace mlir::concretelang::RT;
// using namespace mlir::tensor;
namespace {
struct DataflowTaskOpBufferizationInterface
struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
: public BufferizableOpInterface::ExternalModel<
DataflowTaskOpBufferizationInterface, DataflowTaskOp> {
DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface,
DerefWorkFunctionArgumentPtrPlaceholderOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
@@ -44,22 +46,20 @@ struct DataflowTaskOpBufferizationInterface
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
const BufferizationOptions &options) const {
DataflowTaskOp taskOp = cast<DataflowTaskOp>(op);
DerefWorkFunctionArgumentPtrPlaceholderOp op =
cast<DerefWorkFunctionArgumentPtrPlaceholderOp>(bop);
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = llvm::any_of(taskOp.getResultTypes(), isTensorType);
bool hasTensorOperand =
llvm::any_of(taskOp.getOperandTypes(), isTensorType);
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
if (!hasTensorResult && !hasTensorOperand)
return success();
SmallVector<mlir::Value, 2> newOperands;
rewriter.setInsertionPoint(taskOp.getBody(), taskOp.getBody()->begin());
for (OpOperand &opOperand : op->getOpOperands()) {
Value oldOperandValue = opOperand.get();
@@ -72,43 +72,11 @@ struct DataflowTaskOpBufferizationInterface
Value buffer = bufferOrErr.getValue();
newOperands.push_back(buffer);
Value tensor =
rewriter.create<bufferization::ToTensorOp>(buffer.getLoc(), buffer);
replaceAllUsesInRegionWith(oldOperandValue, tensor,
taskOp.getBodyRegion());
} else {
newOperands.push_back(opOperand.get());
}
}
if (hasTensorResult) {
WalkResult wr = taskOp.walk([&](DataflowYieldOp yield) {
SmallVector<Value, 2> yieldValues;
for (OpOperand &yieldOperand : yield.getOperation()->getOpOperands())
if (yieldOperand.get().getType().isa<TensorType>()) {
FailureOr<Value> bufferOrErr =
bufferization::getBuffer(rewriter, yieldOperand.get(), options);
if (failed(bufferOrErr))
return WalkResult::interrupt();
yieldValues.push_back(bufferOrErr.getValue());
} else {
yieldValues.push_back(yieldOperand.get());
}
rewriter.setInsertionPointAfter(yield);
rewriter.replaceOpWithNewOp<DataflowYieldOp>(yield.getOperation(),
yieldValues);
return WalkResult::advance();
});
if (wr.wasInterrupted())
return failure();
}
SmallVector<mlir::Type, 2> newResultTypes;
for (OpResult res : op->getResults()) {
@@ -120,17 +88,239 @@ struct DataflowTaskOpBufferizationInterface
}
}
rewriter.setInsertionPoint(taskOp);
DataflowTaskOp newTaskOp = rewriter.create<DataflowTaskOp>(
taskOp.getLoc(), newResultTypes, newOperands);
rewriter.setInsertionPoint(op);
DerefWorkFunctionArgumentPtrPlaceholderOp newOp =
rewriter.create<DerefWorkFunctionArgumentPtrPlaceholderOp>(
op.getLoc(), newResultTypes, newOperands);
newTaskOp.getRegion().takeBody(taskOp.getRegion());
replaceOpWithBufferizedValues(rewriter, op, newTaskOp->getResults());
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
return success();
}
};
struct MakeReadyFutureOpBufferizationInterface
: public BufferizableOpInterface::ExternalModel<
MakeReadyFutureOpBufferizationInterface, MakeReadyFutureOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
const BufferizationOptions &options) const {
MakeReadyFutureOp op = cast<MakeReadyFutureOp>(bop);
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
if (!hasTensorResult && !hasTensorOperand)
return success();
SmallVector<mlir::Value, 2> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
Value oldOperandValue = opOperand.get();
if (oldOperandValue.getType().isa<TensorType>()) {
FailureOr<Value> bufferOrErr =
bufferization::getBuffer(rewriter, opOperand.get(), options);
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
}
}
SmallVector<mlir::Type, 2> newResultTypes;
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());
}
}
rewriter.setInsertionPoint(op);
MakeReadyFutureOp newOp = rewriter.create<MakeReadyFutureOp>(
op.getLoc(), newResultTypes, newOperands);
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
return success();
}
};
struct WorkFunctionReturnOpBufferizationInterface
: public BufferizableOpInterface::ExternalModel<
WorkFunctionReturnOpBufferizationInterface, WorkFunctionReturnOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
const BufferizationOptions &options) const {
WorkFunctionReturnOp op = cast<WorkFunctionReturnOp>(bop);
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
if (!hasTensorResult && !hasTensorOperand)
return success();
SmallVector<mlir::Value, 2> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
Value oldOperandValue = opOperand.get();
if (oldOperandValue.getType().isa<TensorType>()) {
FailureOr<Value> bufferOrErr =
bufferization::getBuffer(rewriter, opOperand.get(), options);
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
}
}
SmallVector<mlir::Type, 2> newResultTypes;
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());
}
}
rewriter.setInsertionPoint(op);
WorkFunctionReturnOp newOp = rewriter.create<WorkFunctionReturnOp>(
op.getLoc(), newResultTypes, newOperands);
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
return success();
}
};
struct AwaitFutureOpBufferizationInterface
: public BufferizableOpInterface::ExternalModel<
AwaitFutureOpBufferizationInterface, AwaitFutureOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
const BufferizationOptions &options) const {
AwaitFutureOp op = cast<AwaitFutureOp>(bop);
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
if (!hasTensorResult && !hasTensorOperand)
return success();
SmallVector<mlir::Value, 2> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
Value oldOperandValue = opOperand.get();
if (oldOperandValue.getType().isa<TensorType>()) {
FailureOr<Value> bufferOrErr =
bufferization::getBuffer(rewriter, opOperand.get(), options);
if (failed(bufferOrErr))
return failure();
Value buffer = bufferOrErr.getValue();
newOperands.push_back(buffer);
} else {
newOperands.push_back(opOperand.get());
}
}
SmallVector<mlir::Type, 2> newResultTypes;
for (OpResult res : op->getResults()) {
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
BaseMemRefType memrefType = getMemRefType(t, options);
newResultTypes.push_back(memrefType);
} else {
newResultTypes.push_back(res.getType());
}
}
rewriter.setInsertionPoint(op);
AwaitFutureOp newOp = rewriter.create<AwaitFutureOp>(
op.getLoc(), newResultTypes, newOperands);
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
return success();
}
};
} // namespace
namespace mlir {
@@ -138,7 +328,13 @@ namespace concretelang {
namespace RT {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, RTDialect *dialect) {
DataflowTaskOp::attachInterface<DataflowTaskOpBufferizationInterface>(*ctx);
DerefWorkFunctionArgumentPtrPlaceholderOp::attachInterface<
DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface>(*ctx);
AwaitFutureOp::attachInterface<AwaitFutureOpBufferizationInterface>(*ctx);
MakeReadyFutureOp::attachInterface<MakeReadyFutureOpBufferizationInterface>(
*ctx);
WorkFunctionReturnOp::attachInterface<
WorkFunctionReturnOpBufferizationInterface>(*ctx);
});
}
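// Usage sketch (hypothetical pipeline setup, not part of this file): the
// external models must be registered before bufferization runs, e.g.
//   mlir::DialectRegistry registry;
//   mlir::concretelang::RT::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);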
} // namespace RT

View File

@@ -93,8 +93,8 @@ static inline size_t _dfr_find_next_execution_locality() {
/// hpx::future<void*> and the size of data within the future. After
/// that come NUM_OUTPUTS pairs of hpx::future<void*>* and size_t for
/// the returns.
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
...) {
void _dfr_create_async_task(wfnptr wfn, void *ctx, size_t num_params,
size_t num_outputs, ...) {
std::vector<void *> refcounted_futures;
std::vector<size_t> param_sizes;
std::vector<uint64_t> param_types;
@@ -104,16 +104,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
va_list args;
va_start(args, num_outputs);
for (size_t i = 0; i < num_params; ++i) {
refcounted_futures.push_back(va_arg(args, void *));
param_sizes.push_back(va_arg(args, uint64_t));
param_types.push_back(va_arg(args, uint64_t));
}
for (size_t i = 0; i < num_outputs; ++i) {
outputs.push_back(va_arg(args, void *));
output_sizes.push_back(va_arg(args, uint64_t));
output_types.push_back(va_arg(args, uint64_t));
}
for (size_t i = 0; i < num_params; ++i) {
refcounted_futures.push_back(va_arg(args, void *));
param_sizes.push_back(va_arg(args, uint64_t));
param_types.push_back(va_arg(args, uint64_t));
}
va_end(args);
// Take a reference on each future argument
@@ -140,12 +140,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 0:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target]()
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
gcc_target,
ctx]() -> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
}));
break;
@@ -153,12 +153,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 1:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0)
gcc_target, ctx](hpx::shared_future<void *> param0)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future));
@@ -167,13 +167,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 2:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -183,15 +183,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 3:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -202,16 +202,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 4:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get(), param3.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -223,18 +223,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 5:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get(), param3.get(),
param4.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -247,19 +247,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 6:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get(), param3.get(),
param4.get(), param5.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -273,20 +273,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 7:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -301,21 +301,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 8:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -331,15 +331,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 9:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(),
@@ -347,7 +347,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param6.get(), param7.get(), param8.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -364,16 +364,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 10:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -381,7 +381,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param8.get(), param9.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -399,17 +399,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 11:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -417,7 +417,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param8.get(), param9.get(), param10.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -436,18 +436,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 12:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -455,7 +455,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param8.get(), param9.get(), param10.get(), param11.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -475,19 +475,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 13:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -496,7 +496,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param12.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -517,20 +517,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 14:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -539,7 +539,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param12.get(), param13.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -561,21 +561,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 15:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -584,7 +584,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param12.get(), param13.get(), param14.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -607,22 +607,22 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 16:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -631,7 +631,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param12.get(), param13.get(), param14.get(), param15.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -655,23 +655,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 17:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -681,7 +681,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param16.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -706,24 +706,24 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 18:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -733,7 +733,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param16.get(), param17.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -759,25 +759,25 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 19:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17,
hpx::shared_future<void *> param18)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17,
hpx::shared_future<void *> param18)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -787,7 +787,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param16.get(), param17.get(), param18.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
@@ -814,26 +814,26 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 20:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types,
gcc_target](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17,
hpx::shared_future<void *> param18,
hpx::shared_future<void *> param19)
gcc_target, ctx](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17,
hpx::shared_future<void *> param18,
hpx::shared_future<void *> param19)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
@@ -843,7 +843,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
param16.get(), param17.get(), param18.get(), param19.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
output_types, ctx);
return gcc_target->execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
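Note: every arity-specific branch of _dfr_create_async_task changes in the same way as the cases shown above — the lambda now also captures ctx and forwards it as the last constructor argument of OpaqueInputData. The following standalone sketch of that shared shape is illustrative only: it substitutes std::shared_future for hpx::shared_future, and OpaqueInputDataSketch, TargetSketch and make_task are hypothetical placeholder names, not part of the runtime.

#include <future>
#include <string>
#include <utility>
#include <vector>

// Placeholder for OpaqueInputData: the runtime context pointer now
// travels with every task's inputs.
struct OpaqueInputDataSketch {
  std::string wfnname;
  std::vector<void *> params;
  void *ctx;
};

// Placeholder for the execution target (gcc_target in the real code).
struct TargetSketch {
  std::future<int> execute_task(const OpaqueInputDataSketch &) {
    return std::async(std::launch::deferred, [] { return 0; });
  }
};

// One arity of the pattern: wait on the parameter futures, pack their
// values together with the captured runtime context, then hand the task
// off to the execution target.
std::future<int> make_task(TargetSketch *target, std::string wfnname,
                           void *ctx, std::shared_future<void *> p0,
                           std::shared_future<void *> p1) {
  std::vector<void *> params = {p0.get(), p1.get()};
  OpaqueInputDataSketch oid{std::move(wfnname), std::move(params), ctx};
  return target->execute_task(oid);
}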
@@ -969,6 +969,7 @@ void _dfr_set_use_omp(bool use_omp) {
bool _dfr_is_jit() { return mlir::concretelang::dfr::is_jit_p; }
bool _dfr_is_root_node() { return mlir::concretelang::dfr::is_root_node_p; }
bool _dfr_use_omp() { return mlir::concretelang::dfr::use_omp_p; }
bool _dfr_is_distributed() { return num_nodes > 1; }
} // namespace dfr
} // namespace concretelang
} // namespace mlir
@@ -1002,6 +1003,7 @@ static inline void _dfr_stop_impl() {
}
static inline void _dfr_start_impl(int argc, char *argv[]) {
BEGIN_TIME(&mlir::concretelang::dfr::init_timer);
mlir::concretelang::dfr::dl_handle = dlopen(nullptr, RTLD_NOW);
// If OpenMP is to be used, we need to force its initialization
@@ -1126,15 +1128,15 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
mlir::concretelang::dfr::num_nodes)
.get();
}
END_TIME(&mlir::concretelang::dfr::init_timer, "Initialization");
}
/* Start/stop functions to be called from within user code (or during
JIT invocation). These serve to pause/resume the runtime
scheduler and to clean up used resources. */
void _dfr_start(int64_t use_dfr_p) {
void _dfr_start(int64_t use_dfr_p, void *ctx) {
BEGIN_TIME(&mlir::concretelang::dfr::whole_timer);
if (use_dfr_p) {
BEGIN_TIME(&mlir::concretelang::dfr::init_timer);
// The first invocation will initialise the runtime. As each call to
// _dfr_start is matched with _dfr_stop, if this is not the first,
// we need to resume the HPX runtime.
@@ -1146,16 +1148,11 @@ void _dfr_start(int64_t use_dfr_p) {
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(
expected, mlir::concretelang::dfr::active))
_dfr_start_impl(0, nullptr);
END_TIME(&mlir::concretelang::dfr::init_timer, "Initialization");
assert(mlir::concretelang::dfr::init_guard ==
mlir::concretelang::dfr::active &&
"DFR runtime failed to initialise");
if (use_dfr_p == 1) {
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
}
// If this is not the root node in a non-JIT execution, then this
// node should only run the scheduler for any incoming work until
// termination is flagged. If this is JIT, we need to run the
@@ -1164,28 +1161,24 @@ void _dfr_start(int64_t use_dfr_p) {
!mlir::concretelang::dfr::_dfr_is_jit())
_dfr_stop_impl();
}
}
// Startup entry point when a RuntimeContext is used
void _dfr_start_c(int64_t use_dfr_p, void *ctx) {
_dfr_start(2);
// If DFR is used and a runtime context is needed, and execution is
// distributed, then broadcast from root to all compute nodes.
if (use_dfr_p && (mlir::concretelang::dfr::num_nodes > 1) &&
(ctx || !mlir::concretelang::dfr::_dfr_is_root_node())) {
BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer);
new mlir::concretelang::dfr::RuntimeContextManager();
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
->setContext(ctx);
if (use_dfr_p) {
if (mlir::concretelang::dfr::num_nodes > 1) {
BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer);
new mlir::concretelang::dfr::RuntimeContextManager();
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
->setContext(ctx);
// If this is not JIT, then the remote nodes never reach _dfr_stop,
// so root should not instantiate this barrier.
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
mlir::concretelang::dfr::_dfr_is_jit())
mlir::concretelang::dfr::_dfr_startup_barrier->wait();
END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting");
}
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
// If this is not JIT, then the remote nodes never reach _dfr_stop,
// so root should not instantiate this barrier.
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
mlir::concretelang::dfr::_dfr_is_jit())
mlir::concretelang::dfr::_dfr_startup_barrier->wait();
END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting");
}
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
}
// This function cannot be used to terminate the runtime as it is
@@ -1216,8 +1209,8 @@ void _dfr_stop(int64_t use_dfr_p) {
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
->clearContext();
}
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
}
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
END_TIME(&mlir::concretelang::dfr::whole_timer, "Total execution");
}
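With these hunks the old _dfr_start/_dfr_start_c pair collapses into a single entry point that also carries the runtime context, and the key-broadcast and compute timers are scoped accordingly. A hedged caller-side sketch of the resulting protocol follows; run_with_runtime and its context pointer are placeholders, since in practice these calls are emitted by the compiler (see the StartStop pass added to the pipeline later in this commit) rather than written by hand.

#include <cstdint>

extern "C" void _dfr_start(int64_t use_dfr_p, void *ctx);
extern "C" void _dfr_stop(int64_t use_dfr_p);

// Placeholder caller; generated code brackets the computation the same way.
void run_with_runtime(void *runtime_context) {
  // Resume (or lazily initialize) the scheduler; when execution is
  // distributed, the context is broadcast from root to the compute nodes.
  _dfr_start(/*use_dfr_p=*/1, runtime_context);
  // ... create and await dataflow tasks here ...
  _dfr_stop(/*use_dfr_p=*/1);
}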
@@ -1298,6 +1291,7 @@ namespace dfr {
namespace {
static bool is_jit_p = false;
static bool use_omp_p = false;
static size_t num_nodes = 1;
static struct timespec compute_timer;
} // namespace
@@ -1307,15 +1301,15 @@ void _dfr_set_use_omp(bool use_omp) { use_omp_p = use_omp; }
bool _dfr_is_jit() { return is_jit_p; }
bool _dfr_is_root_node() { return true; }
bool _dfr_use_omp() { return use_omp_p; }
bool _dfr_is_distributed() { return num_nodes > 1; }
} // namespace dfr
} // namespace concretelang
} // namespace mlir
void _dfr_start(int64_t use_dfr_p) {
void _dfr_start(int64_t use_dfr_p, void *ctx) {
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
}
void _dfr_start_c(int64_t use_dfr_p, void *ctx) { _dfr_start(2); }
void _dfr_stop(int64_t use_dfr_p) {
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
}

View File

@@ -153,6 +153,8 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createBuildDataflowTaskGraphPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
return pm.run(module.getOperation());
}
@@ -294,6 +296,9 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::bufferization::createOneShotBufferizePass(bufferizationOptions);
addPotentiallyNestedPass(pm, std::move(comprBuffPass), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass);
if (parallelizeLoops) {
addPotentiallyNestedPass(pm, mlir::concretelang::createForLoopToParallel(),
enablePass);
@@ -305,14 +310,16 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
// Lower affine
addPotentiallyNestedPass(pm, mlir::createLowerAffinePass(), enablePass);
// Lower Dataflow tasks to DRF
// Finalize the lowering of RT/DFR, which includes:
// - adding type and type-size information for dependences
// - issuing _dfr_start and _dfr_stop calls to start/stop the runtime
// - removing deallocation calls for buffers managed through refcounting
addPotentiallyNestedPass(
pm, mlir::concretelang::createFixupDataflowTaskOpsPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
// Use the buffer deallocation interface to insert future deallocation calls
pm, mlir::concretelang::createFinalizeTaskCreationPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::bufferization::createBufferDeallocationPass(), enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createStartStopPass(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
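Taken together, task-graph construction and the lowering of dataflow tasks to the RT dialect now run before bufferization (in autopar), while task finalization, buffer deallocation and runtime start/stop insertion run afterwards (in lowerStdToLLVMDialect). A condensed ordering sketch, ignoring the addPotentiallyNestedPass/enablePass plumbing and the loop/affine lowering in between, could look as follows.

// Condensed sketch of the resulting pass order; not the literal pipeline code.
mlir::PassManager pm(&context);
// autopar(): build the task graph and lower tasks to the RT dialect,
// now ahead of bufferization.
pm.addPass(mlir::concretelang::createBuildDataflowTaskGraphPass());
pm.addPass(mlir::concretelang::createLowerDataflowTasksPass());
// lowerStdToLLVMDialect(): bufferize, then finalize task creation,
// deallocation and runtime start/stop.
pm.addPass(mlir::bufferization::createOneShotBufferizePass(bufferizationOptions));
pm.addPass(mlir::concretelang::createBufferizeDataflowTaskOpsPass());
pm.addPass(mlir::concretelang::createFinalizeTaskCreationPass());
pm.addPass(mlir::bufferization::createBufferDeallocationPass());
pm.addPass(mlir::concretelang::createStartStopPass());
pm.addPass(mlir::concretelang::createFixupBufferDeallocationPass());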

View File

@@ -13,7 +13,7 @@
std::vector<uint64_t> distributed_results;
TEST(DISABLED_Distributed, nn_med_nested) {
TEST(Distributed, nn_med_nested) {
checkedJit(lambda, R"XXX(
func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> {
%cst = arith.constant dense<"0x01010100010100000001010101000101010101010101010001000101000001010001010100000101000001000001010001000001010100010001000000010100010001010001000001000101010101000100010001000000000100010001000101000001000101010100010001000000000101000100000000000001000100000100000100000001010000010001000101000100010001000100000100000100010101010000000000000000010001010000000100000100010100000100000000010001000101000100000000000101010101000101010101010100010100010100000000000101010100000100010100000001000101000000010101000101000100000101010100010101010000010101010100010000000000000001010101000100010101000001010001010000010001010101000000000000000001000001000000010100000100000101010100010001000000000000010100010101000000010100000100010001010001000000000100010001000101010100010100000001010100010101010100010100010001000001000000000101000101010001000100000101010100000101010100000100010101000100000101000101010100010001000101010100010001010001010000010000010001010000000001000101010001000000000101000000010000010100010001000001000001010101000100010001010100000101000000010001000000000101000101000000010000000001000101010100010001000000000001010000010001000001010101000101010101010100000000000001000100000100000001000000010101010101000000000101010101000100000101000100000000000001000100000101000101010100010000000101000000000100000100000101010000010100000000010000000000010001000100000101010001010101000000000000010000010101010001000000010001010001010000000000000101000000010101010101000001010101000001000001010100000000010001010100000100000101000101010100010001010001000001000100000101000100010100000100010000000101000000010000010001010101010000000101000000010101000001010100000100010001000000000001010000000100010000000000000000000000000001010101010101010101000001010101000001010100000001000101010101010000010101000101010100010101010000010101010100000100000000000101010000000000010101010000000001000000010100000100000001000101010000000001000001000001010001010000010001000101010001010001010101000100010000000100000100010101000000000101010101010001000100000000000101010000010101000001010001010000000001010100000101000001010000000001010101000100010000010101000000000001000101000001010101000101000001000001000000010100010001000101010100010001010000000101000000010001000001000100000101010001000001000001000101010000010001000001000101000000000000000101010000010000000101010100010100010001010101010000000000010001000101010000000001010100000000010001010100010001000001000101000000010100010000010000010001010100010000010001010100010000010100010101010001000100010100010101000100000101010100000100010100000100000000010101000000010001000001010000000101000100000100010101000000010100000101000001010001010100010000000101010000000001010001000000010100010101010001000100010001000001010101000000010001000100000100010101000000000000010100010000000100000000010100010000000100000101010000010101000100010000010100000001000100000000000100000001010101010101000100010001000000010101010100000001000001000001010001000101010100000001010001010100010101000101000000010001010100010101000100000101000101000001000001000001000101010100010001010000000100000101010100000001000000000000010101000100010001000001000001000000000000010100000100000001"> : tensor<200x8xi5>
@@ -106,7 +106,7 @@ func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>>
ASSERT_EXPECTED_FAILURE(lambda.operator()<std::vector<uint64_t>>());
}
TEST(DISABLED_Distributed, nn_med_sequential) {
TEST(Distributed, nn_med_sequential) {
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
checkedJit(lambda, R"XXX(
func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> {