mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): move the lowering of dataflow tasks to RT dialect before bufferization.
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
template <typename TypeConverterType>
|
||||
struct FunctionConstantOpConversion
|
||||
: public mlir::OpRewritePattern<mlir::func::ConstantOp> {
|
||||
FunctionConstantOpConversion(mlir::MLIRContext *ctx,
|
||||
TypeConverterType &converter,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::func::ConstantOp>(ctx, benefit),
|
||||
converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::func::ConstantOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto symTab = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue());
|
||||
assert(funcOp &&
|
||||
"Function symbol missing in symbol table for function constant op.");
|
||||
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
|
||||
.getFunctionType()
|
||||
.cast<mlir::FunctionType>();
|
||||
typename TypeConverterType::SignatureConversion result(
|
||||
funType.getNumInputs());
|
||||
mlir::SmallVector<mlir::Type, 1> newResults;
|
||||
if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) ||
|
||||
failed(converter.convertTypes(funType.getResults(), newResults)))
|
||||
return mlir::failure();
|
||||
auto newType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), result.getConvertedTypes(), newResults);
|
||||
rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newType); });
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static bool isLegal(mlir::func::ConstantOp fun,
|
||||
TypeConverterType &converter) {
|
||||
auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun);
|
||||
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue());
|
||||
assert(funcOp &&
|
||||
"Function symbol missing in symbol table for function constant op.");
|
||||
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
|
||||
.getFunctionType()
|
||||
.cast<mlir::FunctionType>();
|
||||
typename TypeConverterType::SignatureConversion result(
|
||||
funType.getNumInputs());
|
||||
mlir::SmallVector<mlir::Type, 1> newResults;
|
||||
if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) ||
|
||||
failed(converter.convertTypes(funType.getResults(), newResults)))
|
||||
return false;
|
||||
auto newType = mlir::FunctionType::get(
|
||||
fun.getContext(), result.getConvertedTypes(), newResults);
|
||||
return newType == fun.getType();
|
||||
}
|
||||
|
||||
private:
|
||||
TypeConverterType &converter;
|
||||
};
|
||||
@@ -22,7 +22,8 @@ createBuildDataflowTaskGraphPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createBufferizeDataflowTaskOpsPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass> createFinalizeTaskCreationPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass> createStartStopPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createFixupBufferDeallocationPass(bool debug = false);
|
||||
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
|
||||
|
||||
@@ -53,23 +53,6 @@ def BufferizeDataflowTaskOps : Pass<"BufferizeDataflowTaskOps", "mlir::ModuleOp"
|
||||
}];
|
||||
}
|
||||
|
||||
def FixupDataflowTaskOps : Pass<"FixupDataflowTaskOps", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Fix DataflowTaskOp(s) before lowering.";
|
||||
|
||||
let description = [{
|
||||
This pass fixes up code changes that intervene between the
|
||||
BuildDataflowTaskGraph pass and the lowering of the taskgraph to
|
||||
LLVMIR and calls to the DFR runtime system.
|
||||
|
||||
In particular, some operations (e.g., constants, dimension
|
||||
operations, etc.) can be used within the task while only defined
|
||||
outside. In most cases cloning and sinking these operations in the
|
||||
task is the simplest to avoid adding dependences.
|
||||
|
||||
}];
|
||||
}
|
||||
|
||||
def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Outline the body of a DataflowTaskOp into a separate function which will serve as a task work function and lower the task graph to RT.";
|
||||
@@ -82,6 +65,30 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def FinalizeTaskCreation : Pass<"FinalizeTaskCreation", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Finalize the CreateAsyncTaskOp ops.";
|
||||
|
||||
let description = [{
|
||||
This pass adds the lower level information missing in
|
||||
CreateAsyncTaskOp, in particular the type sizes and if required
|
||||
passing the runtime context.
|
||||
}];
|
||||
}
|
||||
|
||||
def StartStop : Pass<"StartStop", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Issue calls to start/stop the runtime system.";
|
||||
|
||||
let description = [{
|
||||
This pass adds calls to _dfr_start and _dfr_stop which
|
||||
respectively initialize/start and pause the runtime system. The
|
||||
start function further distributes the evaluation keys to
|
||||
compute nodes when required and the stop function clears the
|
||||
execution context.
|
||||
}];
|
||||
}
|
||||
|
||||
def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Prevent deallocation of buffers returned as futures by tasks.";
|
||||
|
||||
@@ -24,6 +24,7 @@ void _dfr_set_use_omp(bool);
|
||||
bool _dfr_is_jit();
|
||||
bool _dfr_is_root_node();
|
||||
bool _dfr_use_omp();
|
||||
bool _dfr_is_distributed();
|
||||
|
||||
typedef enum _dfr_task_arg_type {
|
||||
_DFR_TASK_ARG_BASE = 0,
|
||||
|
||||
@@ -69,23 +69,27 @@ struct OpaqueInputData {
|
||||
std::vector<size_t> _param_sizes,
|
||||
std::vector<uint64_t> _param_types,
|
||||
std::vector<size_t> _output_sizes,
|
||||
std::vector<uint64_t> _output_types)
|
||||
std::vector<uint64_t> _output_types, void *_context = nullptr)
|
||||
: wfn_name(_wfn_name), params(std::move(_params)),
|
||||
param_sizes(std::move(_param_sizes)),
|
||||
param_types(std::move(_param_types)),
|
||||
output_sizes(std::move(_output_sizes)),
|
||||
output_types(std::move(_output_types)) {}
|
||||
output_types(std::move(_output_types)), context(_context) {
|
||||
if (_context)
|
||||
params.push_back(_context);
|
||||
}
|
||||
|
||||
OpaqueInputData(const OpaqueInputData &oid)
|
||||
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
|
||||
param_sizes(std::move(oid.param_sizes)),
|
||||
param_types(std::move(oid.param_types)),
|
||||
output_sizes(std::move(oid.output_sizes)),
|
||||
output_types(std::move(oid.output_types)) {}
|
||||
output_types(std::move(oid.output_types)), context(oid.context) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar >> wfn_name;
|
||||
bool has_context;
|
||||
ar >> wfn_name >> has_context;
|
||||
ar >> param_sizes >> param_types;
|
||||
ar >> output_sizes >> output_types;
|
||||
for (size_t p = 0; p < param_sizes.size(); ++p) {
|
||||
@@ -114,27 +118,22 @@ struct OpaqueInputData {
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
// The copied pointer is meaningless - TODO: if the context
|
||||
// can change dynamically (e.g., different evaluation keys)
|
||||
// then this needs updating by passing key ids and retrieving
|
||||
// adequate keys for the context.
|
||||
delete ((char *)params[p]);
|
||||
params[p] =
|
||||
(void *)_dfr_node_level_runtime_context_manager->getContext();
|
||||
} break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
if (has_context)
|
||||
params.push_back(
|
||||
(void *)_dfr_node_level_runtime_context_manager->getContext());
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar << wfn_name;
|
||||
bool has_context = (bool)(context != nullptr);
|
||||
ar << wfn_name << has_context;
|
||||
ar << param_sizes << param_types;
|
||||
ar << output_sizes << output_types;
|
||||
for (size_t p = 0; p < params.size(); ++p) {
|
||||
for (size_t p = 0; p < param_sizes.size(); ++p) {
|
||||
// Save the first level of the data structure - if the parameter
|
||||
// is a tensor/memref, there is a second level.
|
||||
ar << hpx::serialization::make_array((char *)params[p], param_sizes[p]);
|
||||
@@ -152,10 +151,6 @@ struct OpaqueInputData {
|
||||
ar << hpx::serialization::make_array(
|
||||
mref.data + mref.offset * elementSize, size * elementSize);
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
// Nothing to do now - TODO: pass key ids if these are not
|
||||
// unique for a computation.
|
||||
} break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
"Error: invalid task argument type.");
|
||||
@@ -170,6 +165,7 @@ struct OpaqueInputData {
|
||||
std::vector<uint64_t> param_types;
|
||||
std::vector<size_t> output_sizes;
|
||||
std::vector<uint64_t> output_types;
|
||||
void *context;
|
||||
};
|
||||
|
||||
struct OpaqueOutputData {
|
||||
@@ -214,9 +210,6 @@ struct OpaqueOutputData {
|
||||
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
|
||||
nullptr;
|
||||
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
|
||||
} break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
@@ -243,9 +236,6 @@ struct OpaqueOutputData {
|
||||
size *= mref.sizes[r];
|
||||
ar << hpx::serialization::make_array(
|
||||
mref.data + mref.offset * elementSize, size * elementSize);
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
|
||||
} break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
@@ -282,121 +272,121 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
wfn(output);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output);
|
||||
wfn(output, inputs.params[0]);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output);
|
||||
wfn(output, inputs.params[0], inputs.params[1]);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output);
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2]);
|
||||
break;
|
||||
case 4:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], output);
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3]);
|
||||
break;
|
||||
case 5:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], output);
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4]);
|
||||
break;
|
||||
case 6:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5], output);
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5]);
|
||||
break;
|
||||
case 7:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], output);
|
||||
inputs.params[6]);
|
||||
break;
|
||||
case 8:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], output);
|
||||
inputs.params[6], inputs.params[7]);
|
||||
break;
|
||||
case 9:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8], output);
|
||||
inputs.params[6], inputs.params[7], inputs.params[8]);
|
||||
break;
|
||||
case 10:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], output);
|
||||
inputs.params[9]);
|
||||
break;
|
||||
case 11:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], output);
|
||||
inputs.params[9], inputs.params[10]);
|
||||
break;
|
||||
case 12:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11], output);
|
||||
inputs.params[9], inputs.params[10], inputs.params[11]);
|
||||
break;
|
||||
case 13:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], output);
|
||||
inputs.params[12]);
|
||||
break;
|
||||
case 14:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], output);
|
||||
inputs.params[12], inputs.params[13]);
|
||||
break;
|
||||
case 15:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14], output);
|
||||
inputs.params[12], inputs.params[13], inputs.params[14]);
|
||||
break;
|
||||
case 16:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output);
|
||||
inputs.params[15]);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output);
|
||||
inputs.params[15], inputs.params[16]);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output);
|
||||
inputs.params[15], inputs.params[16], inputs.params[17]);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output);
|
||||
inputs.params[18]);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output);
|
||||
inputs.params[18], inputs.params[19]);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
@@ -415,127 +405,127 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
wfn(output1, output2);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0]);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1]);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1,
|
||||
output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], output1, output2);
|
||||
break;
|
||||
case 4:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3]);
|
||||
break;
|
||||
case 5:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4]);
|
||||
break;
|
||||
case 6:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5], output1,
|
||||
output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], output1, output2);
|
||||
break;
|
||||
case 7:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6]);
|
||||
break;
|
||||
case 8:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7]);
|
||||
break;
|
||||
case 9:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8], output1,
|
||||
output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], output1, output2);
|
||||
break;
|
||||
case 10:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9]);
|
||||
break;
|
||||
case 11:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10]);
|
||||
break;
|
||||
case 12:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11], output1,
|
||||
output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], output1, output2);
|
||||
break;
|
||||
case 13:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12]);
|
||||
break;
|
||||
case 14:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13]);
|
||||
break;
|
||||
case 15:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14], output1,
|
||||
output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], output1, output2);
|
||||
break;
|
||||
case 16:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15]);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16]);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output1,
|
||||
output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16],
|
||||
inputs.params[17], output1, output2);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16],
|
||||
inputs.params[17], inputs.params[18]);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output1, output2);
|
||||
wfn(output1, output2, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16],
|
||||
inputs.params[17], inputs.params[18], inputs.params[19]);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
@@ -555,127 +545,127 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
wfn(output1, output2, output3);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0]);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1]);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1,
|
||||
output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], output1, output2, output3);
|
||||
break;
|
||||
case 4:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3]);
|
||||
break;
|
||||
case 5:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4]);
|
||||
break;
|
||||
case 6:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5], output1,
|
||||
output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], output1, output2, output3);
|
||||
break;
|
||||
case 7:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6]);
|
||||
break;
|
||||
case 8:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7]);
|
||||
break;
|
||||
case 9:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8], output1,
|
||||
output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], output1, output2, output3);
|
||||
break;
|
||||
case 10:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9]);
|
||||
break;
|
||||
case 11:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10]);
|
||||
break;
|
||||
case 12:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11], output1,
|
||||
output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], output1, output2, output3);
|
||||
break;
|
||||
case 13:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12]);
|
||||
break;
|
||||
case 14:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13]);
|
||||
break;
|
||||
case 15:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14], output1,
|
||||
output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], output1, output2, output3);
|
||||
break;
|
||||
case 16:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15]);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16]);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output1,
|
||||
output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16],
|
||||
inputs.params[17], output1, output2, output3);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16],
|
||||
inputs.params[17], inputs.params[18]);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output1, output2, output3);
|
||||
wfn(output1, output2, output3, inputs.params[0], inputs.params[1],
|
||||
inputs.params[2], inputs.params[3], inputs.params[4],
|
||||
inputs.params[5], inputs.params[6], inputs.params[7],
|
||||
inputs.params[8], inputs.params[9], inputs.params[10],
|
||||
inputs.params[11], inputs.params[12], inputs.params[13],
|
||||
inputs.params[14], inputs.params[15], inputs.params[16],
|
||||
inputs.params[17], inputs.params[18], inputs.params[19]);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
@@ -693,12 +683,10 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
// Deallocate input data buffers from OID deserialization (load)
|
||||
if (!_dfr_is_root_node()) {
|
||||
for (size_t p = 0; p < inputs.param_sizes.size(); ++p) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) != _DFR_TASK_ARG_CONTEXT) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
|
||||
delete (static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
|
||||
->data);
|
||||
delete ((char *)inputs.params[p]);
|
||||
}
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
|
||||
delete (static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
|
||||
->data);
|
||||
delete ((char *)inputs.params[p]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ struct RuntimeContextManager {
|
||||
}
|
||||
}
|
||||
|
||||
RuntimeContext **getContext() { return &context; }
|
||||
RuntimeContext *getContext() { return context; }
|
||||
|
||||
void clearContext() {
|
||||
if (context != nullptr)
|
||||
|
||||
@@ -15,7 +15,7 @@ extern "C" {
|
||||
typedef void (*wfnptr)(...);
|
||||
|
||||
void *_dfr_make_ready_future(void *, size_t);
|
||||
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
|
||||
void _dfr_create_async_task(wfnptr, void *, size_t, size_t, ...);
|
||||
void _dfr_register_work_function(wfnptr);
|
||||
void *_dfr_await_future(void *);
|
||||
|
||||
@@ -26,8 +26,7 @@ void _dfr_deallocate_future(void *);
|
||||
void _dfr_deallocate_future_data(void *);
|
||||
|
||||
/* Initialisation & termination. */
|
||||
void _dfr_start_c(int64_t, void *);
|
||||
void _dfr_start(int64_t);
|
||||
void _dfr_start(int64_t, void *);
|
||||
void _dfr_stop(int64_t);
|
||||
|
||||
void _dfr_terminate();
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
@@ -108,6 +109,16 @@ public:
|
||||
newShape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -956,6 +967,14 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<
|
||||
ConcreteToBConcreteTypeConverter>::isLegal(op, converter);
|
||||
});
|
||||
patterns
|
||||
.insert<FunctionConstantOpConversion<ConcreteToBConcreteTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
target.addDynamicallyLegalOp<mlir::scf::ForOp>([&](mlir::scf::ForOp forOp) {
|
||||
return converter.isLegal(forOp.getInitArgs().getTypes()) &&
|
||||
@@ -969,19 +988,44 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
converter.isLegal(op->getOperandTypes());
|
||||
});
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>,
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>>(&getContext(), converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
|
||||
#include "concretelang/Conversion/FHEToTFHE/Patterns.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTDialect.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
|
||||
@@ -53,6 +56,16 @@ public:
|
||||
eint.getContext(), eint));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -269,6 +282,11 @@ struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<FHEToTFHETypeConverter>::isLegal(
|
||||
op, converter);
|
||||
});
|
||||
|
||||
// Add all patterns required to lower all ops from `FHE` to
|
||||
// `TFHE`
|
||||
@@ -292,6 +310,8 @@ struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
|
||||
patterns.add<SubEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintIntOpPattern>(&getContext());
|
||||
patterns.add<FunctionConstantOpConversion<FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
@@ -319,16 +339,43 @@ struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
patterns, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
|
||||
converter);
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
@@ -52,6 +53,16 @@ public:
|
||||
this->glweInterPBSType(glwe));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) {
|
||||
@@ -293,6 +304,14 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<
|
||||
TFHEGlobalParametrizationTypeConverter>::isLegal(op, converter);
|
||||
});
|
||||
patterns.add<
|
||||
FunctionConstantOpConversion<TFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
|
||||
@@ -354,16 +373,43 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
patterns, target, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
|
||||
converter);
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Patterns.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
@@ -50,6 +51,16 @@ public:
|
||||
mlir::concretelang::convertTypeToLWE(glwe.getContext(), glwe));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -174,10 +185,17 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<
|
||||
TFHEToConcreteTypeConverter>::isLegal(op, converter);
|
||||
});
|
||||
// Add all patterns required to lower all ops from `TFHE` to
|
||||
// `Concrete`
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
patterns.add<FunctionConstantOpConversion<TFHEToConcreteTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
populateWithGeneratedTFHEToConcrete(patterns);
|
||||
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
@@ -219,16 +237,43 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
patterns, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
|
||||
converter);
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::func::ReturnOp>(
|
||||
target, converter);
|
||||
|
||||
@@ -32,28 +32,13 @@ struct AddRuntimeContextToFuncOpPattern
|
||||
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
|
||||
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
|
||||
newInputs, oldFuncType.getResults());
|
||||
// Create the new func
|
||||
mlir::func::FuncOp newFuncOp = rewriter.create<mlir::func::FuncOp>(
|
||||
oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy);
|
||||
|
||||
// Create the arguments of the new func
|
||||
mlir::Region &newFuncBody = newFuncOp.getBody();
|
||||
mlir::Block *newFuncEntryBlock = new mlir::Block();
|
||||
llvm::SmallVector<mlir::Location> locations(newFuncTy.getInputs().size(),
|
||||
oldFuncOp.getLoc());
|
||||
rewriter.updateRootInPlace(oldFuncOp,
|
||||
[&] { oldFuncOp.setType(newFuncTy); });
|
||||
oldFuncOp.getBody().front().addArgument(
|
||||
rewriter.getType<mlir::concretelang::Concrete::ContextType>(),
|
||||
oldFuncOp.getLoc());
|
||||
|
||||
newFuncEntryBlock->addArguments(newFuncTy.getInputs(), locations);
|
||||
newFuncBody.push_back(newFuncEntryBlock);
|
||||
|
||||
// Clone the old body to the new one
|
||||
mlir::BlockAndValueMapping map;
|
||||
for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) {
|
||||
map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index()));
|
||||
}
|
||||
for (auto &op : oldFuncOp.getBody().front()) {
|
||||
newFuncEntryBlock->push_back(op.clone(map));
|
||||
}
|
||||
rewriter.eraseOp(oldFuncOp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -72,6 +57,50 @@ struct AddRuntimeContextToFuncOpPattern
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct FunctionConstantOpConversion
|
||||
: public mlir::OpRewritePattern<mlir::func::ConstantOp> {
|
||||
FunctionConstantOpConversion(mlir::MLIRContext *ctx,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::func::ConstantOp>(ctx, benefit) {}
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::func::ConstantOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto symTab = mlir::SymbolTable::getNearestSymbolTable(op);
|
||||
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue());
|
||||
assert(funcOp &&
|
||||
"Function symbol missing in symbol table for function constant op.");
|
||||
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
|
||||
.getFunctionType()
|
||||
.cast<mlir::FunctionType>();
|
||||
mlir::SmallVector<mlir::Type> newInputs(funType.getInputs().begin(),
|
||||
funType.getInputs().end());
|
||||
newInputs.push_back(
|
||||
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
|
||||
mlir::FunctionType newFuncTy =
|
||||
rewriter.getType<mlir::FunctionType>(newInputs, funType.getResults());
|
||||
|
||||
rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newFuncTy); });
|
||||
return mlir::success();
|
||||
}
|
||||
static bool isLegal(mlir::func::ConstantOp fun) {
|
||||
auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun);
|
||||
auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue());
|
||||
assert(funcOp &&
|
||||
"Function symbol missing in symbol table for function constant op.");
|
||||
mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp)
|
||||
.getFunctionType()
|
||||
.cast<mlir::FunctionType>();
|
||||
if ((AddRuntimeContextToFuncOpPattern::isLegal(
|
||||
mlir::cast<mlir::func::FuncOp>(funcOp)) &&
|
||||
fun.getType() == funType) ||
|
||||
fun.getType() != funType)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
struct AddRuntimeContextPass
|
||||
: public AddRuntimeContextBase<AddRuntimeContextPass> {
|
||||
void runOnOperation() final;
|
||||
@@ -90,8 +119,13 @@ void AddRuntimeContextPass::runOnOperation() {
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return AddRuntimeContextToFuncOpPattern::isLegal(funcOp);
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion::isLegal(op);
|
||||
});
|
||||
|
||||
patterns.add<AddRuntimeContextToFuncOpPattern>(patterns.getContext());
|
||||
patterns.add<FunctionConstantOpConversion>(patterns.getContext());
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
|
||||
@@ -153,6 +153,9 @@ mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
mlir::Block *block = op->getBlock();
|
||||
while (block != nullptr) {
|
||||
if (llvm::isa<mlir::func::FuncOp>(block->getParentOp())) {
|
||||
block = &mlir::cast<mlir::func::FuncOp>(block->getParentOp())
|
||||
.getBody()
|
||||
.front();
|
||||
|
||||
auto context =
|
||||
std::find_if(block->getArguments().rbegin(),
|
||||
@@ -160,7 +163,6 @@ mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
return arg.getType()
|
||||
.isa<mlir::concretelang::Concrete::ContextType>();
|
||||
});
|
||||
|
||||
assert(context != block->getArguments().rend() &&
|
||||
"Cannot find the Concrete.context");
|
||||
|
||||
|
||||
@@ -10,11 +10,16 @@
|
||||
#include <concretelang/Dialect/RT/IR/RTOps.h>
|
||||
#include <concretelang/Dialect/RT/IR/RTTypes.h>
|
||||
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <concretelang/Conversion/Utils/FuncConstOpConversion.h>
|
||||
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Bufferize.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
@@ -27,68 +32,35 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace {
|
||||
class BufferizeDataflowYieldOp
|
||||
: public OpConversionPattern<RT::DataflowYieldOp> {
|
||||
|
||||
class BufferizeRTTypesConverter
|
||||
: public mlir::bufferization::BufferizeTypeConverter {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(RT::DataflowYieldOp op, RT::DataflowYieldOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<RT::DataflowYieldOp>(op, mlir::TypeRange(),
|
||||
adaptor.getOperands());
|
||||
return success();
|
||||
BufferizeRTTypesConverter() {
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::FunctionType type) {
|
||||
SignatureConversion result(type.getNumInputs());
|
||||
mlir::SmallVector<mlir::Type, 1> newResults;
|
||||
if (failed(this->convertSignatureArgs(type.getInputs(), result)) ||
|
||||
failed(this->convertTypes(type.getResults(), newResults)))
|
||||
return type;
|
||||
return mlir::FunctionType::get(type.getContext(),
|
||||
result.getConvertedTypes(), newResults);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
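The converter above handles RT.future and RT.pointer by recursing into the wrapped element type and rebuilding the wrapper around the converted element, so e.g. a future of a tensor becomes a future of the corresponding memref. A minimal standalone sketch of that recursive rebuild on a toy type model (all types and names here are hypothetical, not part of the diff):

#include <cstdio>
#include <memory>
#include <string>

// Toy type model: a base type, a tensor, or a wrapper (Future/Pointer) of
// another type.
struct Type {
  enum Kind { Base, Tensor, Future, Pointer } kind;
  std::shared_ptr<Type> element; // only used for Future/Pointer
  std::string name;              // only used for Base/Tensor
};

static std::shared_ptr<Type> make(Type::Kind k, std::shared_ptr<Type> elt,
                                  std::string name = "") {
  return std::make_shared<Type>(Type{k, std::move(elt), std::move(name)});
}

// "Bufferize" a type: tensors become memrefs, wrappers are rebuilt around the
// recursively converted element, everything else is left unchanged.
static std::shared_ptr<Type> convert(const std::shared_ptr<Type> &t) {
  switch (t->kind) {
  case Type::Tensor:
    return make(Type::Base, nullptr, "memref<" + t->name + ">");
  case Type::Future:
  case Type::Pointer:
    return make(t->kind, convert(t->element));
  default:
    return t;
  }
}

int main() {
  auto fut = make(Type::Future, make(Type::Tensor, nullptr, "8xi64"));
  auto converted = convert(fut);
  std::printf("future element: %s\n", converted->element->name.c_str());
  return 0;
}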
|
||||
|
||||
namespace {
|
||||
class BufferizeDataflowTaskOp : public OpConversionPattern<RT::DataflowTaskOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(RT::DataflowTaskOp op, RT::DataflowTaskOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
SmallVector<Type> newResults;
|
||||
(void)getTypeConverter()->convertTypes(op.getResultTypes(), newResults);
|
||||
auto newop = rewriter.create<RT::DataflowTaskOp>(op.getLoc(), newResults,
|
||||
adaptor.getOperands());
|
||||
    // We cannot clone the body here, as cloned ops must be legalized (which
    // would break on the YieldOp). Instead, use mergeBlocks, which moves the
    // ops rather than cloning them.
    rewriter.mergeBlocks(op.getBody(), newop.getBody(),
                         newop.getBody()->getArguments());
    // Previous bufferization has generated buffer cast ops for task results
    // that used to be tensors. These cannot simply be replaced directly, as
    // the task's results would still be live.
|
||||
for (auto res : llvm::enumerate(op.getResults())) {
|
||||
// If this result is getting bufferized ...
|
||||
if (res.value().getType() !=
|
||||
getTypeConverter()->convertType(res.value().getType())) {
|
||||
for (auto &use : llvm::make_early_inc_range(res.value().getUses())) {
|
||||
// ... and its uses are in `ToMemrefOp`s, then we
|
||||
// replace further uses of the buffer cast.
|
||||
if (isa<mlir::bufferization::ToMemrefOp>(use.getOwner())) {
|
||||
rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, {newop.getResults()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void populateRTBufferizePatterns(
|
||||
mlir::bufferization::BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeDataflowYieldOp, BufferizeDataflowTaskOp>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// For documentation see Autopar.td
|
||||
struct BufferizeDataflowTaskOpsPass
|
||||
@@ -97,15 +69,75 @@ struct BufferizeDataflowTaskOpsPass
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
mlir::bufferization::BufferizeTypeConverter typeConverter;
|
||||
BufferizeRTTypesConverter typeConverter;
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
|
||||
populateRTBufferizePatterns(typeConverter, patterns);
|
||||
populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, typeConverter);
|
||||
patterns.add<FunctionConstantOpConversion<BufferizeRTTypesConverter>>(
|
||||
context, typeConverter);
|
||||
|
||||
// Forbid all RT ops that still use/return tensors
|
||||
target.addDynamicallyLegalDialect<RT::RTDialect>(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
target.addDynamicallyLegalDialect<mlir::func::FuncDialect>([&](Operation
|
||||
*op) {
|
||||
if (auto fun = dyn_cast_or_null<mlir::func::FuncOp>(op))
|
||||
return typeConverter.isSignatureLegal(fun.getFunctionType()) &&
|
||||
typeConverter.isLegal(&fun.getBody());
|
||||
if (auto fun = dyn_cast_or_null<mlir::func::ConstantOp>(op))
|
||||
return FunctionConstantOpConversion<BufferizeRTTypesConverter>::isLegal(
|
||||
fun, typeConverter);
|
||||
return typeConverter.isLegal(op);
|
||||
});
|
||||
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
typeConverter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target,
|
||||
typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target,
|
||||
typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target,
|
||||
typeConverter);
|
||||
|
||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
@@ -61,9 +60,8 @@ static bool isAggregatingBeneficiary(Operation *op) {
|
||||
return isa<FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintIntOp,
|
||||
FHE::AddEintOp, FHE::SubIntEintOp, FHE::SubEintIntOp,
|
||||
FHE::MulEintIntOp, FHE::SubEintOp, FHE::NegEintOp,
|
||||
FHELinalg::FromElementOp, arith::ConstantOp, memref::DimOp,
|
||||
arith::SelectOp, mlir::arith::CmpIOp, memref::GetGlobalOp,
|
||||
memref::CastOp>(op);
|
||||
FHELinalg::FromElementOp, arith::ConstantOp, arith::SelectOp,
|
||||
mlir::arith::CmpIOp>(op);
|
||||
}
|
||||
|
||||
static bool
|
||||
@@ -95,87 +93,6 @@ aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool isFunctionCallName(OpOperand *use, StringRef name) {
|
||||
func::CallOp call = dyn_cast_or_null<mlir::func::CallOp>(use->getOwner());
|
||||
if (!call)
|
||||
return false;
|
||||
SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
if (!sym)
|
||||
return false;
|
||||
func::FuncOp called = dyn_cast_or_null<func::FuncOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(call, sym));
|
||||
if (!called)
|
||||
return false;
|
||||
return called.getName() == name;
|
||||
}
|
||||
|
||||
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
|
||||
for (auto &use : val.getUses()) {
|
||||
aliasedUses.insert(&use);
|
||||
if (dyn_cast<ViewLikeOpInterface>(use.getOwner()))
|
||||
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
|
||||
}
|
||||
}
|
||||
|
||||
static bool aggregateOutputMemrefAllocations(
|
||||
Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
|
||||
if (beneficiaryOps.count(op))
|
||||
return true;
|
||||
|
||||
if (!isa<mlir::memref::AllocOp>(op))
|
||||
return false;
|
||||
|
||||
Value val = op->getResults().front();
|
||||
DenseSet<OpOperand *> aliasedUses;
|
||||
getAliasedUses(val, aliasedUses);
|
||||
|
||||
  // Helper function checking whether a memref use writes to memory.
  auto hasMemoryWriteEffect = [&](OpOperand *use) {
    // Call ops targeting concrete-ffi do not implement the memory-effects
    // interface, so they are handled separately here.
    // TODO: this could be handled better in BConcrete or higher.
|
||||
if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") ||
|
||||
isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") ||
|
||||
isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_keyswitch_lwe_u64") ||
|
||||
isFunctionCallName(use, "memref_bootstrap_lwe_u64"))
|
||||
if (use->getOwner()->getOperand(0) == use->get())
|
||||
return true;
|
||||
|
||||
if (isFunctionCallName(use, "memref_copy_one_rank"))
|
||||
if (use->getOwner()->getOperand(1) == use->get())
|
||||
return true;
|
||||
|
||||
// Otherwise we rely on the memory effect interface
|
||||
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
|
||||
if (!effectInterface)
|
||||
return false;
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> effects;
|
||||
effectInterface.getEffects(effects);
|
||||
for (auto eff : effects)
|
||||
if (isa<MemoryEffects::Write>(eff.getEffect()) &&
|
||||
eff.getValue() == use->get())
|
||||
return true;
|
||||
return false;
|
||||
};
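Since these concrete-ffi calls carry no memory-effects interface, the check above hard-codes which operand each of them writes: the first operand for most of the listed functions, the second for memref_copy_one_rank. A standalone sketch of that lookup (the function names come from the code above; everything else is illustrative, not part of the diff):

#include <cstdio>
#include <map>
#include <string>

// Map each concrete-ffi call to the index of the operand it writes.
static const std::map<std::string, unsigned> kWrittenOperand = {
    {"memref_expand_lut_in_trivial_glwe_ct_u64", 0},
    {"memref_add_lwe_ciphertexts_u64", 0},
    {"memref_add_plaintext_lwe_ciphertext_u64", 0},
    {"memref_mul_cleartext_lwe_ciphertext_u64", 0},
    {"memref_negate_lwe_ciphertext_u64", 0},
    {"memref_keyswitch_lwe_u64", 0},
    {"memref_bootstrap_lwe_u64", 0},
    {"memref_copy_one_rank", 1},
};

// Does a call to `callee` write to its operand number `operandIdx`?
static bool writesOperand(const std::string &callee, unsigned operandIdx) {
  auto it = kWrittenOperand.find(callee);
  return it != kWrittenOperand.end() && it->second == operandIdx;
}

int main() {
  std::printf("%d\n", writesOperand("memref_copy_one_rank", 1));     // 1
  std::printf("%d\n", writesOperand("memref_bootstrap_lwe_u64", 1)); // 0
  return 0;
}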
|
||||
|
||||
  // We need to check whether this allocated memref is written within this
  // task.
  // TODO: for now we assume there are no partial writes or read/writes.
|
||||
for (auto use : aliasedUses)
|
||||
if (hasMemoryWriteEffect(use) &&
|
||||
use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
|
||||
// We will sink the operation, mark its results as now available.
|
||||
beneficiaryOps.insert(op);
|
||||
for (Value result : op->getResults())
|
||||
availableValues.insert(result);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Region &taskOpBody = taskOp.body();
|
||||
|
||||
@@ -191,8 +108,6 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
if (!operandOp)
|
||||
continue;
|
||||
aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues);
|
||||
aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues,
|
||||
taskOp);
|
||||
}
|
||||
|
||||
// Insert operations so that the defs get cloned before uses.
|
||||
@@ -283,45 +198,5 @@ std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
|
||||
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// For documentation see Autopar.td
|
||||
struct FixupDataflowTaskOpsPass
|
||||
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
module->walk([](RT::DataflowTaskOp op) {
|
||||
assert(!failed(coarsenDFTask(op)) &&
|
||||
"Failing to sink operations into DFT");
|
||||
});
|
||||
|
||||
    // Finally, clean up any remaining alloc/dealloc pairs that no longer serve
    // a purpose.
    SetVector<Operation *> eraseOps;
    module->walk([&](memref::AllocOp op) {
      // If this memref.alloc's only remaining use is the dealloc, erase both.
|
||||
if (op->hasOneUse() &&
|
||||
isa<mlir::memref::DeallocOp>(op->use_begin()->getOwner())) {
|
||||
eraseOps.insert(op->use_begin()->getOwner());
|
||||
eraseOps.insert(op);
|
||||
}
|
||||
});
|
||||
for (auto op : eraseOps)
|
||||
op->erase();
|
||||
}
|
||||
|
||||
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug) {
|
||||
return std::make_unique<FixupDataflowTaskOpsPass>(debug);
|
||||
}
|
||||
|
||||
} // end namespace concretelang
|
||||
} // end namespace mlir
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <concretelang/Conversion/Tools.h>
|
||||
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteDialect.h>
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteOps.h>
|
||||
#include <concretelang/Dialect/Concrete/IR/ConcreteTypes.h>
|
||||
@@ -18,9 +19,7 @@
|
||||
#include <concretelang/Dialect/RT/IR/RTTypes.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/math.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
#include <concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h>
|
||||
#include <llvm/IR/Instructions.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
@@ -39,7 +38,9 @@
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/SymbolTable.h>
|
||||
#include <mlir/Interfaces/ViewLikeInterface.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
@@ -67,10 +68,10 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
// types, which will be changed to use an indirection when lowering.
|
||||
SmallVector<Type, 4> operandTypes;
|
||||
operandTypes.reserve(DFTOp.getNumOperands() + DFTOp.getNumResults());
|
||||
  for (Value operand : DFTOp.getOperands())
    operandTypes.push_back(RT::PointerType::get(operand.getType()));
  for (Value res : DFTOp.getResults())
    operandTypes.push_back(RT::PointerType::get(res.getType()));
  for (Value operand : DFTOp.getOperands())
    operandTypes.push_back(RT::PointerType::get(operand.getType()));
|
||||
|
||||
FunctionType type = FunctionType::get(DFTOp.getContext(), operandTypes, {});
|
||||
auto outlinedFunc = builder.create<func::FuncOp>(loc, workFunctionName, type);
|
||||
@@ -82,6 +83,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
outlinedFuncBody.push_back(outlinedEntryBlock);
|
||||
|
||||
BlockAndValueMapping map;
|
||||
int input_offset = DFTOp.getNumResults();
|
||||
Block &entryBlock = outlinedFuncBody.front();
|
||||
builder.setInsertionPointToStart(&entryBlock);
|
||||
for (auto operand : llvm::enumerate(DFTOp.getOperands())) {
|
||||
@@ -89,7 +91,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
auto derefdop =
|
||||
builder.create<RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
DFTOp.getLoc(), operand.value().getType(),
|
||||
entryBlock.getArgument(operand.index()));
|
||||
entryBlock.getArgument(operand.index() + input_offset));
|
||||
map.map(operand.value(), derefdop->getResult(0));
|
||||
}
|
||||
DFTOpBody.cloneInto(&outlinedFuncBody, map);
|
||||
@@ -99,17 +101,14 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
builder.setInsertionPointToEnd(&entryBlock);
|
||||
builder.create<cf::BranchOp>(loc, clonedDFTOpEntry);
|
||||
|
||||
// TODO: we use a WorkFunctionReturnOp to tie return to the
|
||||
// corresponding argument. This can be lowered to a copy/deref for
|
||||
// shared memory and pointers, but needs to be handled for
|
||||
// distributed memory.
|
||||
// WorkFunctionReturnOp ties return to the corresponding argument.
|
||||
// This is lowered to a copy/deref for shared memory and pointers,
|
||||
// and handled in the serializer for distributed memory.
|
||||
outlinedFunc.walk([&](RT::DataflowYieldOp op) {
|
||||
OpBuilder replacer(op);
|
||||
int output_offset = DFTOp.getNumOperands();
|
||||
for (auto ret : llvm::enumerate(op.getOperands()))
|
||||
replacer.create<RT::WorkFunctionReturnOp>(
|
||||
op.getLoc(), ret.value(),
|
||||
outlinedFunc.getArgument(ret.index() + output_offset));
|
||||
op.getLoc(), ret.value(), outlinedFunc.getArgument(ret.index()));
|
||||
replacer.create<func::ReturnOp>(op.getLoc());
|
||||
op.erase();
|
||||
});
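With the new argument order used above (result pointers first, then operand pointers), the index arithmetic reduces to two fixed offsets: result i of the task maps to work-function argument i, and operand j maps to argument j + numResults (the input_offset above). A small standalone sketch of that mapping, with hypothetical counts (not part of the diff):

#include <cassert>

// Work-function argument layout: [results..., operands...].
static unsigned resultArgIndex(unsigned resultIdx) { return resultIdx; }
static unsigned operandArgIndex(unsigned operandIdx, unsigned numResults) {
  return operandIdx + numResults; // input_offset == numResults
}

int main() {
  const unsigned numResults = 2, numOperands = 3;
  assert(resultArgIndex(1) == 1);              // second return pointer
  assert(operandArgIndex(0, numResults) == 2); // first operand pointer
  assert(operandArgIndex(2, numResults) == 4); // last operand pointer
  (void)numOperands;
  return 0;
}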
|
||||
@@ -126,33 +125,37 @@ static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement,
|
||||
}
|
||||
}
|
||||
|
||||
static mlir::Type stripType(mlir::Type type) {
|
||||
if (type.isa<RT::FutureType>())
|
||||
return stripType(type.dyn_cast<RT::FutureType>().getElementType());
|
||||
if (type.isa<RT::PointerType>())
|
||||
return stripType(type.dyn_cast<RT::PointerType>().getElementType());
|
||||
return type;
|
||||
}
|
||||
|
||||
// TODO: Fix type sizes. For now we're using some default values.
|
||||
static std::pair<mlir::Value, mlir::Value>
|
||||
static std::pair<Value, Value>
|
||||
getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
|
||||
DataLayout dataLayout = DataLayout::closest(val.getDefiningOp());
|
||||
Type type = (val.getType().isa<RT::FutureType>())
|
||||
? val.getType().dyn_cast<RT::FutureType>().getElementType()
|
||||
: val.getType();
|
||||
Type type = stripType(val.getType());
|
||||
|
||||
  // In the case of a memref, we need to determine (conservatively) how much
  // space is needed to store the memref descriptor itself. Overshooting by a
  // few bytes is not an issue; the main thing is to properly account for the
  // rank.
|
||||
if (type.isa<mlir::MemRefType>()) {
|
||||
// Space for the allocated and aligned pointers, and offset
|
||||
Value ptrs_offset =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(24));
|
||||
// For the sizes and shapes arrays, we need 2*8 = 16 times the rank in bytes
|
||||
Value multiplier =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
|
||||
unsigned _rank = type.dyn_cast<mlir::MemRefType>().getRank();
|
||||
Value rank = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(_rank));
|
||||
Value sizes_shapes = builder.create<LLVM::MulOp>(loc, rank, multiplier);
|
||||
Value typeSize =
|
||||
builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
|
||||
|
||||
// Space for the allocated and aligned pointers, and offset plus
|
||||
// rank * sizes and strides
|
||||
size_t element_size;
|
||||
unsigned rank = type.dyn_cast<mlir::MemRefType>().getRank();
|
||||
Type elementType = type.dyn_cast<mlir::MemRefType>().getElementType();
|
||||
|
||||
element_size = dataLayout.getTypeSize(elementType);
|
||||
|
||||
    size_t size = 24 + 16 * rank;
    Value typeSize =
        builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(size));
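The constant 24 + 16 * rank corresponds to the ranked-memref descriptor mentioned in the comment above, assuming 8-byte pointers and indices: allocated pointer, aligned pointer and offset (3 x 8 bytes), plus one size and one stride per dimension (2 x 8 x rank bytes). A minimal standalone sketch of that computation, under the same 64-bit assumption (not part of the diff):

#include <cstddef>
#include <cstdio>

// Conservative byte size of a ranked memref descriptor, assuming 8-byte
// pointers and 8-byte index values (i.e. a 64-bit target).
static std::size_t memrefDescriptorSize(unsigned rank) {
  const std::size_t ptrsAndOffset = 3 * 8;   // allocated ptr, aligned ptr, offset
  const std::size_t sizesAndStrides = 2 * 8 * rank; // sizes[] and strides[]
  return ptrsAndOffset + sizesAndStrides;    // == 24 + 16 * rank
}

int main() {
  for (unsigned rank = 0; rank <= 3; ++rank)
    std::printf("rank %u -> %zu bytes\n", rank, memrefDescriptorSize(rank));
  return 0;
}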
|
||||
|
||||
// Assume here that the base type is a simple scalar-type or at
|
||||
// least its size can be determined.
|
||||
// size_t elementAttr = dataLayout.getTypeSize(elementType);
|
||||
@@ -160,7 +163,6 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
|
||||
// elementAttr <<= 8;
|
||||
// elementAttr |= _DFR_TASK_ARG_MEMREF;
|
||||
uint64_t elementAttr = 0;
|
||||
size_t element_size = dataLayout.getTypeSize(elementType);
|
||||
elementAttr =
|
||||
dfr::_dfr_set_arg_type(elementAttr, dfr::_DFR_TASK_ARG_MEMREF);
|
||||
elementAttr = dfr::_dfr_set_memref_element_size(elementAttr, element_size);
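The two dfr helpers above fold the argument kind and the memref element size into a single 64-bit descriptor that travels with each task argument; the exact bit layout is defined by the runtime. Purely as an illustration, a standalone sketch with a made-up layout (low byte = kind, upper bits = element size; this encoding is an assumption, not the runtime's real one):

#include <cstdint>
#include <cstdio>

// Hypothetical encoding, for illustration only: the real layout lives in the
// DFR runtime's _dfr_set_arg_type / _dfr_set_memref_element_size helpers.
constexpr std::uint64_t kTaskArgMemref = 2;

static std::uint64_t setArgType(std::uint64_t attr, std::uint64_t kind) {
  return (attr & ~0xffULL) | (kind & 0xffULL); // kind in the low byte
}
static std::uint64_t setMemrefElementSize(std::uint64_t attr,
                                          std::uint64_t size) {
  return (attr & 0xffULL) | (size << 8); // element size above the kind
}

int main() {
  std::uint64_t attr = 0;
  attr = setArgType(attr, kTaskArgMemref);
  attr = setMemrefElementSize(attr, 8); // e.g. 8-byte elements
  std::printf("descriptor: 0x%llx\n", (unsigned long long)attr);
  return 0;
}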
|
||||
@@ -169,34 +171,19 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
|
||||
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
|
||||
}
|
||||
|
||||
// Unranked memrefs should be lowered to just pointer + size, so we need 16
|
||||
// bytes.
|
||||
assert(!type.isa<mlir::UnrankedMemRefType>() &&
|
||||
"UnrankedMemRefType not currently supported");
|
||||
if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT));
|
||||
Value typeSize =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
|
||||
}
|
||||
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_BASE));
|
||||
|
||||
  // FHE types are converted to pointers, so we take their size to be 8 bytes
  // until we can determine the actual size of the underlying types.
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
|
||||
Value result =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
} else if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT));
|
||||
Value result =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
}
|
||||
|
||||
// For all other types, get type size.
|
||||
Value result = builder.create<arith::ConstantOp>(
|
||||
Value typeSize = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type)));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
|
||||
}
|
||||
|
||||
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
@@ -212,7 +199,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
if (!val.getType().isa<RT::FutureType>()) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Type futType = RT::FutureType::get(val.getType());
|
||||
Value memrefCloned;
|
||||
|
||||
// Find out if this value is needed in any other task
|
||||
SmallVector<Operation *, 2> taskOps;
|
||||
@@ -224,20 +210,10 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
|
||||
first = op;
|
||||
builder.setInsertionPoint(first);
|
||||
|
||||
      // If we are building a future on a MemRef, we need to clone the memref,
      // because the deallocation pass does not synchronize with task
      // execution.
      if (val.getType().isa<mlir::MemRefType>()) {
|
||||
memrefCloned = builder.create<arith::ConstantOp>(
|
||||
val.getLoc(), builder.getI64IntegerAttr(1));
|
||||
} else {
|
||||
memrefCloned = builder.create<arith::ConstantOp>(
|
||||
val.getLoc(), builder.getI64IntegerAttr(0));
|
||||
}
|
||||
|
||||
auto mrf = builder.create<RT::MakeReadyFutureOp>(val.getLoc(), futType,
|
||||
val, memrefCloned);
|
||||
auto mrf = builder.create<RT::MakeReadyFutureOp>(
|
||||
val.getLoc(), futType, val,
|
||||
builder.create<arith::ConstantOp>(val.getLoc(),
|
||||
builder.getI64IntegerAttr(0)));
|
||||
replaceAllUsesInDFTsInRegionWith(val, mrf, opBody);
|
||||
}
|
||||
}
|
||||
@@ -246,7 +222,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
// DataflowTaskOp. This also includes the necessary handling of
|
||||
// operands and results (conversion to/from futures and propagation).
|
||||
  SmallVector<Value, 4> catOperands;
  int size = 3 + DFTOp.getNumResults() * 3 + DFTOp.getNumOperands() * 3;
  int size = 3 + DFTOp.getNumResults() + DFTOp.getNumOperands();
|
||||
catOperands.reserve(size);
|
||||
auto fnptr = builder.create<mlir::func::ConstantOp>(
|
||||
DFTOp.getLoc(), workFunction.getFunctionType(),
|
||||
@@ -258,12 +234,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
catOperands.push_back(fnptr.getResult());
|
||||
catOperands.push_back(numIns.getResult());
|
||||
catOperands.push_back(numOuts.getResult());
|
||||
for (auto operand : DFTOp.getOperands()) {
|
||||
auto op_size = getTaskArgumentSizeAndType(operand, DFTOp.getLoc(), builder);
|
||||
catOperands.push_back(operand);
|
||||
catOperands.push_back(op_size.first);
|
||||
catOperands.push_back(op_size.second);
|
||||
}
|
||||
|
||||
// We need to adjust the results for the CreateAsyncTaskOp which
|
||||
// are the work function's returns through pointers passed as
|
||||
@@ -276,11 +246,11 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
|
||||
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
|
||||
futType);
|
||||
auto op_size = getTaskArgumentSizeAndType(result, DFTOp.getLoc(), builder);
|
||||
map.map(result, brpp->getResult(0));
|
||||
catOperands.push_back(brpp->getResult(0));
|
||||
catOperands.push_back(op_size.first);
|
||||
catOperands.push_back(op_size.second);
|
||||
}
|
||||
for (auto operand : DFTOp.getOperands()) {
|
||||
catOperands.push_back(operand);
|
||||
}
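After this change, the operand vector built for the CreateAsyncTaskOp at this stage only holds the function pointer, the two counts, one return-pointer placeholder per result and the raw operands (hence 3 + numResults + numOperands above); the per-argument size and type-descriptor constants are appended later, in FinalizeTaskCreationPass. A standalone sketch of the before/after operand counts, with hypothetical numbers (not part of the diff):

#include <cstdio>

int main() {
  const int numResults = 1, numOperands = 2;
  // Old layout: fnptr, numIns, numOuts, then (value, size, type) per
  // result placeholder and per operand.
  const int oldSize = 3 + numResults * 3 + numOperands * 3;
  // New layout at this point: fnptr, numIns, numOuts, then one entry per
  // result placeholder and per operand; sizes/types are added later.
  const int newSize = 3 + numResults + numOperands;
  std::printf("old: %d operands, new: %d operands\n", oldSize, newSize);
  return 0;
}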
|
||||
builder.create<RT::CreateAsyncTaskOp>(
|
||||
DFTOp.getLoc(),
|
||||
@@ -341,59 +311,6 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
|
||||
}
|
||||
|
||||
static void propagateMemRefLayoutInDFTs(RT::DataflowTaskOp op, Value val,
|
||||
Value newval) {
|
||||
for (auto &use : llvm::make_early_inc_range(val.getUses()))
|
||||
if (use.getOwner()->getParentOfType<RT::DataflowTaskOp>() != nullptr) {
|
||||
OpBuilder builder(use.getOwner());
|
||||
Value cast_newval = builder.create<mlir::memref::CastOp>(
|
||||
val.getLoc(), val.getType(), newval);
|
||||
use.set(cast_newval);
|
||||
}
|
||||
}
|
||||
|
||||
static void cloneMemRefTaskArgumentsWithIdentityMaps(RT::DataflowTaskOp op) {
|
||||
OpBuilder builder(op);
|
||||
for (Value val : op.getOperands()) {
|
||||
if (val.getType().isa<mlir::MemRefType>()) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
|
||||
      // Find out whether this memref is needed in any other task, so that it
      // can be cloned before all such uses.
|
||||
SmallVector<Operation *, 2> taskOps;
|
||||
for (auto &use : val.getUses())
|
||||
if (isa<RT::DataflowTaskOp>(use.getOwner()))
|
||||
taskOps.push_back(use.getOwner());
|
||||
Operation *first = op;
|
||||
for (auto op : taskOps)
|
||||
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
|
||||
first = op;
|
||||
builder.setInsertionPoint(first);
|
||||
|
||||
      // Get the type of the memref that we will clone. If this is a subview,
      // we discard the layout map so that we copy to a contiguous layout,
      // which effectively pre-serializes the data.
|
||||
MemRefType mrType_base = val.getType().dyn_cast<mlir::MemRefType>();
|
||||
MemRefType mrType = mrType_base;
|
||||
if (!mrType_base.getLayout().isIdentity()) {
|
||||
unsigned rank = mrType_base.getRank();
|
||||
mrType = MemRefType::Builder(mrType_base)
|
||||
.setShape(mrType_base.getShape())
|
||||
.setLayout(AffineMapAttr::get(
|
||||
builder.getMultiDimIdentityMap(rank)));
|
||||
}
|
||||
Value newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
|
||||
.getResult();
|
||||
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
|
||||
// Value cast_newval = builder.create<mlir::memref::CastOp>(val.getLoc(),
|
||||
// mrType_base, newval);
|
||||
replaceAllUsesInDFTsInRegionWith(
|
||||
val, newval, op->getParentOfType<func::FuncOp>().getBody());
|
||||
propagateMemRefLayoutInDFTs(op, val, newval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For documentation see Autopar.td
|
||||
struct LowerDataflowTasksPass
|
||||
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
|
||||
@@ -413,8 +330,8 @@ struct LowerDataflowTasksPass
|
||||
SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func);
|
||||
SmallVector<std::pair<RT::DataflowTaskOp, func::FuncOp>, 4> outliningMap;
|
||||
|
||||
// Outline DataflowTaskOp bodies to work functions
|
||||
func.walk([&](RT::DataflowTaskOp op) {
|
||||
cloneMemRefTaskArgumentsWithIdentityMaps(op);
|
||||
auto workFunctionName =
|
||||
Twine("_dfr_DFT_work_function__") +
|
||||
Twine(op->getParentOfType<func::FuncOp>().getName()) +
|
||||
@@ -423,15 +340,15 @@ struct LowerDataflowTasksPass
|
||||
outlineWorkFunction(op, workFunctionName.str());
|
||||
outliningMap.push_back(
|
||||
std::pair<RT::DataflowTaskOp, func::FuncOp>(op, outlinedFunc));
|
||||
workFunctions.push_back(outlinedFunc);
|
||||
symbolTable.insert(outlinedFunc);
|
||||
workFunctions.push_back(outlinedFunc);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
// Lower the DF task ops to RT dialect ops.
|
||||
for (auto mapping : outliningMap)
|
||||
lowerDataflowTaskOp(mapping.first, mapping.second);
|
||||
|
||||
// Gather all entry points (assuming no recursive calls to entry points)
|
||||
// Main is always an entry-point - otherwise check if this
|
||||
// function is called within the module. TODO: we assume no
|
||||
// recursion.
|
||||
@@ -449,17 +366,6 @@ struct LowerDataflowTasksPass
|
||||
});
|
||||
|
||||
for (auto entryPoint : entryPoints) {
|
||||
// Check if this entry point uses a context - do this before we
|
||||
// remove arguments in remote nodes
|
||||
int ctxIndex = -1;
|
||||
for (auto arg : llvm::enumerate(entryPoint.getArguments()))
|
||||
if (arg.value()
|
||||
.getType()
|
||||
.isa<mlir::concretelang::Concrete::ContextType>()) {
|
||||
ctxIndex = arg.index();
|
||||
break;
|
||||
}
|
||||
|
||||
// If this is a JIT invocation and we're not on the root node,
|
||||
// we do not need to do any computation, only register all work
|
||||
// functions with the runtime system
|
||||
@@ -486,49 +392,6 @@ struct LowerDataflowTasksPass
|
||||
// runtime.
|
||||
for (auto wf : workFunctions)
|
||||
registerWorkFunction(entryPoint, wf);
|
||||
|
||||
// Issue _dfr_start/stop calls for this function
|
||||
OpBuilder builder(entryPoint.getBody());
|
||||
builder.setInsertionPointToStart(&entryPoint.getBody().front());
|
||||
int useDFR = (workFunctions.empty()) ? 0 : 1;
|
||||
Value useDFRVal = builder.create<arith::ConstantOp>(
|
||||
entryPoint.getLoc(), builder.getI64IntegerAttr(useDFR));
|
||||
|
||||
if (ctxIndex >= 0) {
|
||||
auto startFunTy =
|
||||
(dfr::_dfr_is_root_node())
|
||||
? mlir::FunctionType::get(
|
||||
entryPoint->getContext(),
|
||||
{useDFRVal.getType(),
|
||||
entryPoint.getArgument(ctxIndex).getType()},
|
||||
{})
|
||||
: mlir::FunctionType::get(entryPoint->getContext(),
|
||||
{useDFRVal.getType()}, {});
|
||||
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start_c",
|
||||
startFunTy);
|
||||
(dfr::_dfr_is_root_node())
|
||||
? builder.create<mlir::func::CallOp>(
|
||||
entryPoint.getLoc(), "_dfr_start_c", mlir::TypeRange(),
|
||||
mlir::ValueRange(
|
||||
{useDFRVal, entryPoint.getArgument(ctxIndex)}))
|
||||
: builder.create<mlir::func::CallOp>(entryPoint.getLoc(),
|
||||
"_dfr_start_c",
|
||||
mlir::TypeRange(), useDFRVal);
|
||||
} else {
|
||||
auto startFunTy = mlir::FunctionType::get(entryPoint->getContext(),
|
||||
{useDFRVal.getType()}, {});
|
||||
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start",
|
||||
startFunTy);
|
||||
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_start",
|
||||
mlir::TypeRange(), useDFRVal);
|
||||
}
|
||||
builder.setInsertionPoint(entryPoint.getBody().back().getTerminator());
|
||||
auto stopFunTy = mlir::FunctionType::get(entryPoint->getContext(),
|
||||
{useDFRVal.getType()}, {});
|
||||
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_stop",
|
||||
stopFunTy);
|
||||
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_stop",
|
||||
mlir::TypeRange(), useDFRVal);
|
||||
}
|
||||
}
|
||||
LowerDataflowTasksPass(bool debug) : debug(debug){};
|
||||
@@ -544,6 +407,188 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug) {
|
||||
|
||||
namespace {
|
||||
|
||||
// For documentation see Autopar.td
|
||||
struct StartStopPass : public StartStopBase<StartStopPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
int useDFR = 0;
|
||||
SmallVector<func::FuncOp, 1> entryPoints;
|
||||
|
||||
module.walk([&](mlir::func::FuncOp func) {
|
||||
// Do not add start/stop to work functions - but if any are
|
||||
// present, then we need to activate the runtime
|
||||
if (func->getAttr("_dfr_work_function_attribute")) {
|
||||
useDFR = 1;
|
||||
} else {
|
||||
// Main is always an entry-point - otherwise check if this
|
||||
// function is called within the module. TODO: we assume no
|
||||
// recursion.
|
||||
if (func.getName() == "main")
|
||||
entryPoints.push_back(func);
|
||||
else {
|
||||
bool found = false;
|
||||
module.walk([&](mlir::func::CallOp op) {
|
||||
if (getCalledFunction(op) == func)
|
||||
found = true;
|
||||
});
|
||||
if (!found)
|
||||
entryPoints.push_back(func);
|
||||
}
|
||||
}
|
||||
});
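Entry points are therefore the functions that nothing else in the module calls, plus main, which is always treated as an entry point (recursion is explicitly not handled). A standalone sketch of the same selection over a toy call graph (function names hypothetical, not part of the diff):

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

// Select entry points: "main", plus every function that is never called.
static std::vector<std::string>
entryPoints(const std::map<std::string, std::vector<std::string>> &callGraph) {
  std::set<std::string> called;
  for (const auto &fn : callGraph)
    for (const auto &callee : fn.second)
      called.insert(callee);
  std::vector<std::string> entries;
  for (const auto &fn : callGraph)
    if (fn.first == "main" || !called.count(fn.first))
      entries.push_back(fn.first);
  return entries;
}

int main() {
  std::map<std::string, std::vector<std::string>> cg = {
      {"main", {"helper"}}, {"helper", {}}, {"unused_entry", {"helper"}}};
  for (const auto &e : entryPoints(cg))
    std::printf("entry point: %s\n", e.c_str()); // main, unused_entry
  return 0;
}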
|
||||
|
||||
for (auto entryPoint : entryPoints) {
|
||||
// Issue _dfr_start/stop calls for this function
|
||||
OpBuilder builder(entryPoint.getBody());
|
||||
builder.setInsertionPointToStart(&entryPoint.getBody().front());
|
||||
Value useDFRVal = builder.create<arith::ConstantOp>(
|
||||
entryPoint.getLoc(), builder.getI64IntegerAttr(useDFR));
|
||||
|
||||
// Check if this entry point uses a context
|
||||
Value ctx = nullptr;
|
||||
if (dfr::_dfr_is_root_node())
|
||||
for (auto arg : llvm::enumerate(entryPoint.getArguments()))
|
||||
if (arg.value()
|
||||
.getType()
|
||||
.isa<mlir::concretelang::Concrete::ContextType>()) {
|
||||
ctx = arg.value();
|
||||
break;
|
||||
}
|
||||
if (!ctx)
|
||||
ctx = builder.create<arith::ConstantOp>(entryPoint.getLoc(),
|
||||
builder.getI64IntegerAttr(0));
|
||||
|
||||
auto startFunTy = mlir::FunctionType::get(
|
||||
entryPoint->getContext(), {useDFRVal.getType(), ctx.getType()}, {});
|
||||
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_start",
|
||||
startFunTy);
|
||||
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_start",
|
||||
mlir::TypeRange(),
|
||||
mlir::ValueRange({useDFRVal, ctx}));
|
||||
builder.setInsertionPoint(entryPoint.getBody().back().getTerminator());
|
||||
auto stopFunTy = mlir::FunctionType::get(entryPoint->getContext(),
|
||||
{useDFRVal.getType()}, {});
|
||||
(void)insertForwardDeclaration(entryPoint, builder, "_dfr_stop",
|
||||
stopFunTy);
|
||||
builder.create<mlir::func::CallOp>(entryPoint.getLoc(), "_dfr_stop",
|
||||
mlir::TypeRange(), useDFRVal);
|
||||
}
|
||||
}
|
||||
StartStopPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createStartStopPass(bool debug) {
|
||||
return std::make_unique<StartStopPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// For documentation see Autopar.td
|
||||
struct FinalizeTaskCreationPass
|
||||
: public FinalizeTaskCreationBase<FinalizeTaskCreationPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
std::vector<Operation *> ops;
|
||||
|
||||
module.walk([&](RT::CreateAsyncTaskOp catOp) {
|
||||
OpBuilder builder(catOp);
|
||||
SmallVector<Value, 4> operands;
|
||||
|
||||
// Determine if this task needs a runtime context
|
||||
Value ctx = nullptr;
|
||||
SymbolRefAttr sym =
|
||||
catOp->getAttr("workfn").dyn_cast_or_null<SymbolRefAttr>();
|
||||
assert(sym && "Work function symbol attribute missing.");
|
||||
func::FuncOp workfn = dyn_cast_or_null<func::FuncOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(catOp, sym));
|
||||
assert(workfn && "Task work function missing.");
|
||||
if (workfn.getNumArguments() > catOp.getNumOperands() - 3)
|
||||
ctx = *catOp->getParentOfType<func::FuncOp>().getArguments().rbegin();
|
||||
else
|
||||
ctx = builder.create<arith::ConstantOp>(catOp.getLoc(),
|
||||
builder.getI64IntegerAttr(0));
|
||||
int index = 0;
|
||||
for (auto op : catOp.getOperands()) {
|
||||
operands.push_back(op);
|
||||
        // Add the runtime context in second position - unconditionally, so we
        // avoid checking whether it is needed. It can be null when not
        // relevant.
        if (index == 0)
          operands.push_back(ctx);
        // The first three operands are the function pointer, the number of
        // inputs and the number of outputs - nothing further to do for these.
        if (++index <= 3)
          continue;
|
||||
auto op_size = getTaskArgumentSizeAndType(op, catOp.getLoc(), builder);
|
||||
operands.push_back(op_size.first);
|
||||
operands.push_back(op_size.second);
|
||||
}
|
||||
|
||||
builder.create<RT::CreateAsyncTaskOp>(catOp.getLoc(), sym, operands);
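The rebuilt operand list above therefore ends up as: function pointer, runtime context, number of inputs, number of outputs, followed by a (value, size, type-descriptor) triple for every return-pointer placeholder and operand. A standalone sketch of that packing over plain integers, with hypothetical argument descriptors (not part of the diff):

#include <cstdio>
#include <vector>

struct Arg { int value, size, typeDesc; };

// Pack the operand list the way FinalizeTaskCreationPass lays it out:
// [fnptr, ctx, numIns, numOuts, v0, s0, t0, v1, s1, t1, ...]
static std::vector<int> packAsyncTaskOperands(int fnptr, int ctx, int numIns,
                                              int numOuts,
                                              const std::vector<Arg> &args) {
  std::vector<int> out = {fnptr, ctx, numIns, numOuts};
  for (const Arg &a : args) {
    out.push_back(a.value);
    out.push_back(a.size);
    out.push_back(a.typeDesc);
  }
  return out;
}

int main() {
  // Hypothetical values: one return placeholder and two operands.
  std::vector<Arg> args = {{100, 8, 1}, {200, 40, 2}, {300, 8, 1}};
  auto packed = packAsyncTaskOperands(1, 2, 2, 1, args);
  std::printf("packed size: %zu\n", packed.size()); // 4 + 3 * 3 == 13
  return 0;
}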
|
||||
ops.push_back(catOp);
|
||||
});
|
||||
for (auto op : ops) {
|
||||
op->erase();
|
||||
}
|
||||
|
||||
    // If we are building a future on a MemRef, we need to flatten it to a
    // contiguous, identity-layout buffer.

    // TODO: the performance of shared memory can be improved by allowing
    // view-like access instead of cloning, but memory deallocation would then
    // need to be synchronized appropriately.
|
||||
module.walk([&](RT::MakeReadyFutureOp op) {
|
||||
OpBuilder builder(op);
|
||||
|
||||
Value val = op.getOperand(0);
|
||||
Value clone = op.getOperand(1);
|
||||
if (val.getType().isa<mlir::MemRefType>()) {
|
||||
MemRefType mrType_base = val.getType().dyn_cast<mlir::MemRefType>();
|
||||
MemRefType mrType = mrType_base;
|
||||
if (!mrType_base.getLayout().isIdentity()) {
|
||||
unsigned rank = mrType_base.getRank();
|
||||
mrType = MemRefType::Builder(mrType_base)
|
||||
.setShape(mrType_base.getShape())
|
||||
.setLayout(AffineMapAttr::get(
|
||||
builder.getMultiDimIdentityMap(rank)));
|
||||
}
|
||||
// We need to make a copy of this MemRef to allow deallocation
|
||||
// based on refcounting
|
||||
Value newval =
|
||||
builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
|
||||
.getResult();
|
||||
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
|
||||
clone = builder.create<arith::ConstantOp>(op.getLoc(),
|
||||
builder.getI64IntegerAttr(1));
|
||||
op->setOperand(0, newval);
|
||||
op->setOperand(1, clone);
|
||||
}
|
||||
});
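In the walk above, every memref feeding a MakeReadyFutureOp is copied into a freshly allocated buffer (with an identity layout when the original was a strided subview) and the clone flag is set to 1, so the copy can be deallocated through the runtime's reference counting. A standalone sketch of the contiguity check this relies on, assuming a strided row-major layout (not part of the diff):

#include <cstdio>
#include <vector>

// A strided layout is the identity (row-major contiguous) iff
// stride[last] == 1 and stride[i] == stride[i+1] * shape[i+1].
static bool isIdentityLayout(const std::vector<long> &shape,
                             const std::vector<long> &strides) {
  long expected = 1;
  for (int i = (int)shape.size() - 1; i >= 0; --i) {
    if (strides[(size_t)i] != expected)
      return false;
    expected *= shape[(size_t)i];
  }
  return true;
}

int main() {
  // A 4x8 buffer: contiguous vs. a strided subview taking every other row.
  std::printf("contiguous: %d\n", isIdentityLayout({4, 8}, {8, 1}));
  std::printf("subview:    %d\n", isIdentityLayout({2, 8}, {16, 1}));
  return 0;
}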
|
||||
}
|
||||
FinalizeTaskCreationPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createFinalizeTaskCreationPass(bool debug) {
|
||||
return std::make_unique<FinalizeTaskCreationPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
  for (auto &use : val.getUses()) {
    aliasedUses.insert(&use);
    if (dyn_cast<ViewLikeOpInterface>(use.getOwner()))
      getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
  }
}
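getAliasedUses collects not only the direct uses of a value but, transitively, the uses of every view-like op derived from it, so an access through a subview or cast is still attributed to the original allocation. A standalone sketch of the same traversal over a toy use graph (all names hypothetical, not part of the diff):

#include <cstdio>
#include <set>
#include <vector>

struct Use; // a (user op) entry in the toy model

struct Op {
  const char *name;
  bool isViewLike;          // e.g. a subview/cast in the real IR
  std::vector<Use> *result; // uses of this op's single result, if any
};

struct Use { Op *owner; };

static void getAliasedUses(const std::vector<Use> &uses, std::set<Op *> &out) {
  for (const Use &use : uses) {
    out.insert(use.owner);
    if (use.owner->isViewLike && use.owner->result)
      getAliasedUses(*use.owner->result, out);
  }
}

int main() {
  std::vector<Use> viewUses, allocUses;
  Op store{"store", false, nullptr};
  Op view{"subview", true, &viewUses};
  viewUses.push_back({&store});
  allocUses.push_back({&view});
  std::set<Op *> aliased;
  getAliasedUses(allocUses, aliased);
  std::printf("aliased users: %zu\n", aliased.size()); // subview and store
  return 0;
}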
|
||||
|
||||
// For documentation see Autopar.td
|
||||
struct FixupBufferDeallocationPass
|
||||
: public FixupBufferDeallocationBase<FixupBufferDeallocationPass> {
|
||||
@@ -552,23 +597,21 @@ struct FixupBufferDeallocationPass
|
||||
auto module = getOperation();
|
||||
std::vector<Operation *> ops;
|
||||
|
||||
    // All allocated buffers that are turned into a future - either directly
    // or by being returned from a task - are managed by the DFR runtime
    // system's reference counting, so their explicit deallocations are erased.
|
||||
module.walk([&](RT::WorkFunctionReturnOp retOp) {
|
||||
for (auto &use :
|
||||
llvm::make_early_inc_range(retOp.getOperands().front().getUses()))
|
||||
if (isa<mlir::memref::DeallocOp>(use.getOwner()))
|
||||
ops.push_back(use.getOwner());
|
||||
module.walk([&](mlir::memref::DeallocOp op) {
|
||||
Value alloc = op.getOperand();
|
||||
DenseSet<OpOperand *> aliasedUses;
|
||||
getAliasedUses(alloc, aliasedUses);
|
||||
|
||||
for (auto use : aliasedUses)
|
||||
if (isa<RT::WorkFunctionReturnOp, RT::MakeReadyFutureOp>(
|
||||
use->getOwner())) {
|
||||
ops.push_back(op);
|
||||
return;
|
||||
}
|
||||
});
|
||||
module.walk([&](RT::MakeReadyFutureOp mrfOp) {
|
||||
for (auto &use :
|
||||
llvm::make_early_inc_range(mrfOp.getOperands().front().getUses()))
|
||||
if (isa<mlir::memref::DeallocOp>(use.getOwner()))
|
||||
ops.push_back(use.getOwner());
|
||||
});
|
||||
for (auto op : ops)
|
||||
for (auto op : ops) {
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
FixupBufferDeallocationPass(bool debug) : debug(debug){};
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
|
||||
@@ -21,9 +22,10 @@ using namespace mlir::concretelang::RT;
|
||||
// using namespace mlir::tensor;
|
||||
|
||||
namespace {
|
||||
struct DataflowTaskOpBufferizationInterface
|
||||
struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
DataflowTaskOpBufferizationInterface, DataflowTaskOp> {
|
||||
DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface,
|
||||
DerefWorkFunctionArgumentPtrPlaceholderOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
@@ -44,22 +46,20 @@ struct DataflowTaskOpBufferizationInterface
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
DataflowTaskOp taskOp = cast<DataflowTaskOp>(op);
|
||||
DerefWorkFunctionArgumentPtrPlaceholderOp op =
|
||||
cast<DerefWorkFunctionArgumentPtrPlaceholderOp>(bop);
|
||||
|
||||
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
|
||||
bool hasTensorResult = llvm::any_of(taskOp.getResultTypes(), isTensorType);
|
||||
bool hasTensorOperand =
|
||||
llvm::any_of(taskOp.getOperandTypes(), isTensorType);
|
||||
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
|
||||
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
|
||||
|
||||
if (!hasTensorResult && !hasTensorOperand)
|
||||
return success();
|
||||
|
||||
SmallVector<mlir::Value, 2> newOperands;
|
||||
|
||||
rewriter.setInsertionPoint(taskOp.getBody(), taskOp.getBody()->begin());
|
||||
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
Value oldOperandValue = opOperand.get();
|
||||
|
||||
@@ -72,43 +72,11 @@ struct DataflowTaskOpBufferizationInterface
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
newOperands.push_back(buffer);
|
||||
|
||||
Value tensor =
|
||||
rewriter.create<bufferization::ToTensorOp>(buffer.getLoc(), buffer);
|
||||
|
||||
replaceAllUsesInRegionWith(oldOperandValue, tensor,
|
||||
taskOp.getBodyRegion());
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
}
|
||||
}
|
||||
|
||||
if (hasTensorResult) {
|
||||
WalkResult wr = taskOp.walk([&](DataflowYieldOp yield) {
|
||||
SmallVector<Value, 2> yieldValues;
|
||||
|
||||
for (OpOperand &yieldOperand : yield.getOperation()->getOpOperands())
|
||||
if (yieldOperand.get().getType().isa<TensorType>()) {
|
||||
FailureOr<Value> bufferOrErr =
|
||||
bufferization::getBuffer(rewriter, yieldOperand.get(), options);
|
||||
|
||||
if (failed(bufferOrErr))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
yieldValues.push_back(bufferOrErr.getValue());
|
||||
} else {
|
||||
yieldValues.push_back(yieldOperand.get());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointAfter(yield);
|
||||
rewriter.replaceOpWithNewOp<DataflowYieldOp>(yield.getOperation(),
|
||||
yieldValues);
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (wr.wasInterrupted())
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<mlir::Type, 2> newResultTypes;
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
@@ -120,17 +88,239 @@ struct DataflowTaskOpBufferizationInterface
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(taskOp);
|
||||
DataflowTaskOp newTaskOp = rewriter.create<DataflowTaskOp>(
|
||||
taskOp.getLoc(), newResultTypes, newOperands);
|
||||
rewriter.setInsertionPoint(op);
|
||||
DerefWorkFunctionArgumentPtrPlaceholderOp newOp =
|
||||
rewriter.create<DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
op.getLoc(), newResultTypes, newOperands);
|
||||
|
||||
newTaskOp.getRegion().takeBody(taskOp.getRegion());
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newTaskOp->getResults());
|
||||
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MakeReadyFutureOpBufferizationInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
MakeReadyFutureOpBufferizationInterface, MakeReadyFutureOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
MakeReadyFutureOp op = cast<MakeReadyFutureOp>(bop);
|
||||
|
||||
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
|
||||
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
|
||||
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
|
||||
|
||||
if (!hasTensorResult && !hasTensorOperand)
|
||||
return success();
|
||||
|
||||
SmallVector<mlir::Value, 2> newOperands;
|
||||
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
Value oldOperandValue = opOperand.get();
|
||||
|
||||
if (oldOperandValue.getType().isa<TensorType>()) {
|
||||
FailureOr<Value> bufferOrErr =
|
||||
bufferization::getBuffer(rewriter, opOperand.get(), options);
|
||||
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<mlir::Type, 2> newResultTypes;
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
MakeReadyFutureOp newOp = rewriter.create<MakeReadyFutureOp>(
|
||||
op.getLoc(), newResultTypes, newOperands);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct WorkFunctionReturnOpBufferizationInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
WorkFunctionReturnOpBufferizationInterface, WorkFunctionReturnOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
WorkFunctionReturnOp op = cast<WorkFunctionReturnOp>(bop);
|
||||
|
||||
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
|
||||
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
|
||||
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
|
||||
|
||||
if (!hasTensorResult && !hasTensorOperand)
|
||||
return success();
|
||||
|
||||
SmallVector<mlir::Value, 2> newOperands;
|
||||
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
Value oldOperandValue = opOperand.get();
|
||||
|
||||
if (oldOperandValue.getType().isa<TensorType>()) {
|
||||
FailureOr<Value> bufferOrErr =
|
||||
bufferization::getBuffer(rewriter, opOperand.get(), options);
|
||||
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<mlir::Type, 2> newResultTypes;
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
WorkFunctionReturnOp newOp = rewriter.create<WorkFunctionReturnOp>(
|
||||
op.getLoc(), newResultTypes, newOperands);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct AwaitFutureOpBufferizationInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
AwaitFutureOpBufferizationInterface, AwaitFutureOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *bop, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
AwaitFutureOp op = cast<AwaitFutureOp>(bop);
|
||||
|
||||
auto isTensorType = [](Type t) { return t.isa<TensorType>(); };
|
||||
bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType);
|
||||
bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType);
|
||||
|
||||
if (!hasTensorResult && !hasTensorOperand)
|
||||
return success();
|
||||
|
||||
SmallVector<mlir::Value, 2> newOperands;
|
||||
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
Value oldOperandValue = opOperand.get();
|
||||
|
||||
if (oldOperandValue.getType().isa<TensorType>()) {
|
||||
FailureOr<Value> bufferOrErr =
|
||||
bufferization::getBuffer(rewriter, opOperand.get(), options);
|
||||
|
||||
if (failed(bufferOrErr))
|
||||
return failure();
|
||||
|
||||
Value buffer = bufferOrErr.getValue();
|
||||
newOperands.push_back(buffer);
|
||||
} else {
|
||||
newOperands.push_back(opOperand.get());
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<mlir::Type, 2> newResultTypes;
|
||||
|
||||
for (OpResult res : op->getResults()) {
|
||||
if (TensorType t = res.getType().dyn_cast<TensorType>()) {
|
||||
BaseMemRefType memrefType = getMemRefType(t, options);
|
||||
newResultTypes.push_back(memrefType);
|
||||
} else {
|
||||
newResultTypes.push_back(res.getType());
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
AwaitFutureOp newOp = rewriter.create<AwaitFutureOp>(
|
||||
op.getLoc(), newResultTypes, newOperands);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};

} // namespace

namespace mlir {
@@ -138,7 +328,13 @@ namespace concretelang {
namespace RT {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, RTDialect *dialect) {
DataflowTaskOp::attachInterface<DataflowTaskOpBufferizationInterface>(*ctx);
DerefWorkFunctionArgumentPtrPlaceholderOp::attachInterface<
DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface>(*ctx);
AwaitFutureOp::attachInterface<AwaitFutureOpBufferizationInterface>(*ctx);
MakeReadyFutureOp::attachInterface<MakeReadyFutureOpBufferizationInterface>(
*ctx);
WorkFunctionReturnOp::attachInterface<
WorkFunctionReturnOpBufferizationInterface>(*ctx);
});
}
} // namespace RT

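// Illustrative usage (a sketch, not part of this patch): a hypothetical
// compiler setup would attach these external models to its DialectRegistry
// before running one-shot bufferization. `registry` and `context` below are
// assumed local DialectRegistry/MLIRContext objects:
//
//   mlir::DialectRegistry registry;
//   registry.insert<mlir::concretelang::RT::RTDialect>();
//   mlir::concretelang::RT::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   context.appendDialectRegistry(registry);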
@@ -93,8 +93,8 @@ static inline size_t _dfr_find_next_execution_locality() {
/// hpx::future<void*> and the size of data within the future. After
/// that come NUM_OUTPUTS pairs of hpx::future<void*>* and size_t for
/// the returns.
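// Illustrative call shape for the new signature (hypothetical names and
// values, not part of this patch; the exact variadic layout is the one
// consumed by the va_arg loops below):
//
//   _dfr_create_async_task(work_fn, rt_ctx, /*num_params=*/2, /*num_outputs=*/1,
//                          /* then, for each parameter and each output, a
//                             pointer followed by its size and type words */);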
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
...) {
void _dfr_create_async_task(wfnptr wfn, void *ctx, size_t num_params,
size_t num_outputs, ...) {
std::vector<void *> refcounted_futures;
std::vector<size_t> param_sizes;
std::vector<uint64_t> param_types;
@@ -104,16 +104,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,

va_list args;
va_start(args, num_outputs);
for (size_t i = 0; i < num_params; ++i) {
refcounted_futures.push_back(va_arg(args, void *));
param_sizes.push_back(va_arg(args, uint64_t));
param_types.push_back(va_arg(args, uint64_t));
}
for (size_t i = 0; i < num_outputs; ++i) {
outputs.push_back(va_arg(args, void *));
output_sizes.push_back(va_arg(args, uint64_t));
output_types.push_back(va_arg(args, uint64_t));
}
for (size_t i = 0; i < num_params; ++i) {
refcounted_futures.push_back(va_arg(args, void *));
param_sizes.push_back(va_arg(args, uint64_t));
param_types.push_back(va_arg(args, uint64_t));
}
va_end(args);
|
||||
// Take a reference on each future argument
|
||||
@@ -140,12 +140,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 0:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target]()
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
gcc_target,
|
||||
ctx]() -> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
}));
|
||||
break;
|
||||
@@ -153,12 +153,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 1:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future));
|
||||
@@ -167,13 +167,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 2:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -183,15 +183,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 3:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get(),
|
||||
param2.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -202,16 +202,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 4:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get(),
|
||||
param2.get(), param3.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -223,18 +223,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 5:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get(),
|
||||
param2.get(), param3.get(),
|
||||
param4.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -247,19 +247,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 6:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get(),
|
||||
param2.get(), param3.get(),
|
||||
param4.get(), param5.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -273,20 +273,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 7:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
param4.get(), param5.get(), param6.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -301,21 +301,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 8:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
param4.get(), param5.get(), param6.get(), param7.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -331,15 +331,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 9:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(),
|
||||
@@ -347,7 +347,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param6.get(), param7.get(), param8.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -364,16 +364,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 10:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -381,7 +381,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param8.get(), param9.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -399,17 +399,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 11:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -417,7 +417,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param8.get(), param9.get(), param10.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -436,18 +436,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 12:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -455,7 +455,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param8.get(), param9.get(), param10.get(), param11.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -475,19 +475,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 13:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -496,7 +496,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param12.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -517,20 +517,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 14:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -539,7 +539,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param12.get(), param13.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -561,21 +561,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 15:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -584,7 +584,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param12.get(), param13.get(), param14.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -607,22 +607,22 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 16:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -631,7 +631,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param12.get(), param13.get(), param14.get(), param15.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -655,23 +655,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 17:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -681,7 +681,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param16.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -706,24 +706,24 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 18:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -733,7 +733,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param16.get(), param17.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -759,25 +759,25 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 19:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17,
|
||||
hpx::shared_future<void *> param18)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17,
|
||||
hpx::shared_future<void *> param18)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -787,7 +787,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param16.get(), param17.get(), param18.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -814,26 +814,26 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
case 20:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes, output_types,
|
||||
gcc_target](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17,
|
||||
hpx::shared_future<void *> param18,
|
||||
hpx::shared_future<void *> param19)
|
||||
gcc_target, ctx](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17,
|
||||
hpx::shared_future<void *> param18,
|
||||
hpx::shared_future<void *> param19)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
@@ -843,7 +843,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
param16.get(), param17.get(), param18.get(), param19.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
output_types, ctx);
|
||||
return gcc_target->execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
@@ -969,6 +969,7 @@ void _dfr_set_use_omp(bool use_omp) {
|
||||
bool _dfr_is_jit() { return mlir::concretelang::dfr::is_jit_p; }
|
||||
bool _dfr_is_root_node() { return mlir::concretelang::dfr::is_root_node_p; }
|
||||
bool _dfr_use_omp() { return mlir::concretelang::dfr::use_omp_p; }
|
||||
bool _dfr_is_distributed() { return num_nodes > 1; }
|
||||
} // namespace dfr
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1002,6 +1003,7 @@ static inline void _dfr_stop_impl() {
|
||||
}
|
||||
|
||||
static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::init_timer);
|
||||
mlir::concretelang::dfr::dl_handle = dlopen(nullptr, RTLD_NOW);
|
||||
|
||||
// If OpenMP is to be used, we need to force its initialization
|
||||
@@ -1126,15 +1128,15 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
mlir::concretelang::dfr::num_nodes)
|
||||
.get();
|
||||
}
|
||||
END_TIME(&mlir::concretelang::dfr::init_timer, "Initialization");
|
||||
}
|
||||
|
||||
/* Start/stop functions to be called from within user code (or during
JIT invocation). These serve to pause/resume the runtime
scheduler and to clean up used resources. */
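// Illustrative usage (hypothetical, not part of this patch): the StartStop
// pass inserts a matched pair of these calls around the generated program,
// roughly:
//
//   _dfr_start(/*use_dfr_p=*/1, /*ctx=*/runtime_context);
//   // ... execute the dataflow task graph ...
//   _dfr_stop(/*use_dfr_p=*/1);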
void _dfr_start(int64_t use_dfr_p) {
|
||||
void _dfr_start(int64_t use_dfr_p, void *ctx) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::whole_timer);
|
||||
if (use_dfr_p) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::init_timer);
|
||||
// The first invocation will initialise the runtime. As each call to
// _dfr_start is matched with _dfr_stop, if this is not the first,
// we need to resume the HPX runtime.
@@ -1146,16 +1148,11 @@ void _dfr_start(int64_t use_dfr_p) {
|
||||
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(
|
||||
expected, mlir::concretelang::dfr::active))
|
||||
_dfr_start_impl(0, nullptr);
|
||||
END_TIME(&mlir::concretelang::dfr::init_timer, "Initialization");
|
||||
|
||||
assert(mlir::concretelang::dfr::init_guard ==
|
||||
mlir::concretelang::dfr::active &&
|
||||
"DFR runtime failed to initialise");
|
||||
|
||||
if (use_dfr_p == 1) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
}
|
||||
|
||||
// If this is not the root node in a non-JIT execution, then this
|
||||
// node should only run the scheduler for any incoming work until
|
||||
// termination is flagged. If this is JIT, we need to run the
|
||||
@@ -1164,28 +1161,24 @@ void _dfr_start(int64_t use_dfr_p) {
|
||||
!mlir::concretelang::dfr::_dfr_is_jit())
|
||||
_dfr_stop_impl();
|
||||
}
|
||||
}
|
||||
|
||||
// Startup entry point when a RuntimeContext is used
|
||||
void _dfr_start_c(int64_t use_dfr_p, void *ctx) {
|
||||
_dfr_start(2);
|
||||
// If DFR is used and a runtime context is needed, and execution is
|
||||
// distributed, then broadcast from root to all compute nodes.
|
||||
if (use_dfr_p && (mlir::concretelang::dfr::num_nodes > 1) &&
|
||||
(ctx || !mlir::concretelang::dfr::_dfr_is_root_node())) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer);
|
||||
new mlir::concretelang::dfr::RuntimeContextManager();
|
||||
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
|
||||
->setContext(ctx);
|
||||
|
||||
if (use_dfr_p) {
|
||||
if (mlir::concretelang::dfr::num_nodes > 1) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer);
|
||||
new mlir::concretelang::dfr::RuntimeContextManager();
|
||||
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
|
||||
->setContext(ctx);
|
||||
|
||||
// If this is not JIT, then the remote nodes never reach _dfr_stop,
|
||||
// so root should not instantiate this barrier.
|
||||
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
|
||||
mlir::concretelang::dfr::_dfr_is_jit())
|
||||
mlir::concretelang::dfr::_dfr_startup_barrier->wait();
|
||||
END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting");
|
||||
}
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
// If this is not JIT, then the remote nodes never reach _dfr_stop,
|
||||
// so root should not instantiate this barrier.
|
||||
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
|
||||
mlir::concretelang::dfr::_dfr_is_jit())
|
||||
mlir::concretelang::dfr::_dfr_startup_barrier->wait();
|
||||
END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting");
|
||||
}
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
}
|
||||
|
||||
// This function cannot be used to terminate the runtime as it is
|
||||
@@ -1216,8 +1209,8 @@ void _dfr_stop(int64_t use_dfr_p) {
|
||||
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
|
||||
->clearContext();
|
||||
}
|
||||
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
|
||||
}
|
||||
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
|
||||
END_TIME(&mlir::concretelang::dfr::whole_timer, "Total execution");
|
||||
}
|
||||
|
||||
@@ -1298,6 +1291,7 @@ namespace dfr {
|
||||
namespace {
|
||||
static bool is_jit_p = false;
|
||||
static bool use_omp_p = false;
|
||||
static size_t num_nodes = 1;
|
||||
static struct timespec compute_timer;
|
||||
} // namespace
|
||||
|
||||
@@ -1307,15 +1301,15 @@ void _dfr_set_use_omp(bool use_omp) { use_omp_p = use_omp; }
|
||||
bool _dfr_is_jit() { return is_jit_p; }
|
||||
bool _dfr_is_root_node() { return true; }
|
||||
bool _dfr_use_omp() { return use_omp_p; }
|
||||
bool _dfr_is_distributed() { return num_nodes > 1; }
|
||||
|
||||
} // namespace dfr
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
void _dfr_start(int64_t use_dfr_p) {
|
||||
void _dfr_start(int64_t use_dfr_p, void *ctx) {
|
||||
BEGIN_TIME(&mlir::concretelang::dfr::compute_timer);
|
||||
}
|
||||
void _dfr_start_c(int64_t use_dfr_p, void *ctx) { _dfr_start(2); }
|
||||
void _dfr_stop(int64_t use_dfr_p) {
|
||||
END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute");
|
||||
}
|
||||
|
||||
@@ -153,6 +153,8 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createBuildDataflowTaskGraphPass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
@@ -294,6 +296,9 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::bufferization::createOneShotBufferizePass(bufferizationOptions);
|
||||
|
||||
addPotentiallyNestedPass(pm, std::move(comprBuffPass), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass);
|
||||
|
||||
if (parallelizeLoops) {
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createForLoopToParallel(),
|
||||
enablePass);
|
||||
@@ -305,14 +310,16 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
// Lower affine
|
||||
addPotentiallyNestedPass(pm, mlir::createLowerAffinePass(), enablePass);
|
||||
|
||||
// Lower Dataflow tasks to DRF
// Finalize the lowering of RT/DFR which includes:
// - adding type and typesize information for dependences
// - issue _dfr_start and _dfr_stop calls to start/stop the runtime
// - remove deallocation calls for buffers managed through refcounting
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createFixupDataflowTaskOpsPass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
|
||||
// Use the buffer deallocation interface to insert future deallocation calls
|
||||
pm, mlir::concretelang::createFinalizeTaskCreationPass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::bufferization::createBufferDeallocationPass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createStartStopPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
std::vector<uint64_t> distributed_results;
|
||||
|
||||
TEST(DISABLED_Distributed, nn_med_nested) {
|
||||
TEST(Distributed, nn_med_nested) {
|
||||
checkedJit(lambda, R"XXX(
|
||||
func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> {
|
||||
%cst = arith.constant dense<"0x01010100010100000001010101000101010101010101010001000101000001010001010100000101000001000001010001000001010100010001000000010100010001010001000001000101010101000100010001000000000100010001000101000001000101010100010001000000000101000100000000000001000100000100000100000001010000010001000101000100010001000100000100000100010101010000000000000000010001010000000100000100010100000100000000010001000101000100000000000101010101000101010101010100010100010100000000000101010100000100010100000001000101000000010101000101000100000101010100010101010000010101010100010000000000000001010101000100010101000001010001010000010001010101000000000000000001000001000000010100000100000101010100010001000000000000010100010101000000010100000100010001010001000000000100010001000101010100010100000001010100010101010100010100010001000001000000000101000101010001000100000101010100000101010100000100010101000100000101000101010100010001000101010100010001010001010000010000010001010000000001000101010001000000000101000000010000010100010001000001000001010101000100010001010100000101000000010001000000000101000101000000010000000001000101010100010001000000000001010000010001000001010101000101010101010100000000000001000100000100000001000000010101010101000000000101010101000100000101000100000000000001000100000101000101010100010000000101000000000100000100000101010000010100000000010000000000010001000100000101010001010101000000000000010000010101010001000000010001010001010000000000000101000000010101010101000001010101000001000001010100000000010001010100000100000101000101010100010001010001000001000100000101000100010100000100010000000101000000010000010001010101010000000101000000010101000001010100000100010001000000000001010000000100010000000000000000000000000001010101010101010101000001010101000001010100000001000101010101010000010101000101010100010101010000010101010100000100000000000101010000000000010101010000000001000000010100000100000001000101010000000001000001000001010001010000010001000101010001010001010101000100010000000100000100010101000000000101010101010001000100000000000101010000010101000001010001010000000001010100000101000001010000000001010101000100010000010101000000000001000101000001010101000101000001000001000000010100010001000101010100010001010000000101000000010001000001000100000101010001000001000001000101010000010001000001000101000000000000000101010000010000000101010100010100010001010101010000000000010001000101010000000001010100000000010001010100010001000001000101000000010100010000010000010001010100010000010001010100010000010100010101010001000100010100010101000100000101010100000100010100000100000000010101000000010001000001010000000101000100000100010101000000010100000101000001010001010100010000000101010000000001010001000000010100010101010001000100010001000001010101000000010001000100000100010101000000000000010100010000000100000000010100010000000100000101010000010101000100010000010100000001000100000000000100000001010101010101000100010001000000010101010100000001000001000001010001000101010100000001010001010100010101000101000000010001010100010101000100000101000101000001000001000001000101010100010001010000000100000101010100000001000000000000010101000100010001000001000001000000000000010100000100000001"> : tensor<200x8xi5>
|
||||
@@ -106,7 +106,7 @@ func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>>
|
||||
ASSERT_EXPECTED_FAILURE(lambda.operator()<std::vector<uint64_t>>());
|
||||
}
|
||||
|
||||
TEST(DISABLED_Distributed, nn_med_sequential) {
|
||||
TEST(Distributed, nn_med_sequential) {
|
||||
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
|
||||
checkedJit(lambda, R"XXX(
|
||||
func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> {