feat(compiler): move the lowering of dataflow tasks to RT dialect before bufferization.

This commit is contained in:
Antoniu Pop
2022-08-19 10:52:49 +01:00
committed by Antoniu Pop
parent 26901a32da
commit 2cf80e76eb
20 changed files with 1453 additions and 1026 deletions

View File

@@ -18,7 +18,6 @@
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BlockAndValueMapping.h>
#include <mlir/IR/Builders.h>
@@ -61,9 +60,8 @@ static bool isAggregatingBeneficiary(Operation *op) {
return isa<FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintIntOp,
FHE::AddEintOp, FHE::SubIntEintOp, FHE::SubEintIntOp,
FHE::MulEintIntOp, FHE::SubEintOp, FHE::NegEintOp,
FHELinalg::FromElementOp, arith::ConstantOp, memref::DimOp,
arith::SelectOp, mlir::arith::CmpIOp, memref::GetGlobalOp,
memref::CastOp>(op);
FHELinalg::FromElementOp, arith::ConstantOp, arith::SelectOp,
mlir::arith::CmpIOp>(op);
}
static bool
@@ -95,87 +93,6 @@ aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
return true;
}
static bool isFunctionCallName(OpOperand *use, StringRef name) {
func::CallOp call = dyn_cast_or_null<mlir::func::CallOp>(use->getOwner());
if (!call)
return false;
SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym)
return false;
func::FuncOp called = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupNearestSymbolFrom(call, sym));
if (!called)
return false;
return called.getName() == name;
}
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
for (auto &use : val.getUses()) {
aliasedUses.insert(&use);
if (dyn_cast<ViewLikeOpInterface>(use.getOwner()))
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
}
}
static bool aggregateOutputMemrefAllocations(
Operation *op, SetVector<Operation *> &beneficiaryOps,
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
if (beneficiaryOps.count(op))
return true;
if (!isa<mlir::memref::AllocOp>(op))
return false;
Value val = op->getResults().front();
DenseSet<OpOperand *> aliasedUses;
getAliasedUses(val, aliasedUses);
// Helper function checking if a memref use writes to memory
auto hasMemoryWriteEffect = [&](OpOperand *use) {
// Call ops targeting concrete-ffi do not have memory effects
// interface, so handle apart.
// TODO: this could be handled better in BConcrete or higher.
if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") ||
isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") ||
isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") ||
isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") ||
isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") ||
isFunctionCallName(use, "memref_keyswitch_lwe_u64") ||
isFunctionCallName(use, "memref_bootstrap_lwe_u64"))
if (use->getOwner()->getOperand(0) == use->get())
return true;
if (isFunctionCallName(use, "memref_copy_one_rank"))
if (use->getOwner()->getOperand(1) == use->get())
return true;
// Otherwise we rely on the memory effect interface
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
if (!effectInterface)
return false;
SmallVector<MemoryEffects::EffectInstance, 2> effects;
effectInterface.getEffects(effects);
for (auto eff : effects)
if (isa<MemoryEffects::Write>(eff.getEffect()) &&
eff.getValue() == use->get())
return true;
return false;
};
// We need to check if this allocated memref is written in this task.
// TODO: for now we'll assume that we don't do partial writes or read/writes.
for (auto use : aliasedUses)
if (hasMemoryWriteEffect(use) &&
use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
// We will sink the operation, mark its results as now available.
beneficiaryOps.insert(op);
for (Value result : op->getResults())
availableValues.insert(result);
return true;
}
return false;
}
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
Region &taskOpBody = taskOp.body();
@@ -191,8 +108,6 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
if (!operandOp)
continue;
aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues);
aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues,
taskOp);
}
// Insert operations so that the defs get cloned before uses.
@@ -283,45 +198,5 @@ std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
}
namespace {
/// For documentation see Autopar.td
struct FixupDataflowTaskOpsPass
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {
void runOnOperation() override {
auto module = getOperation();
module->walk([](RT::DataflowTaskOp op) {
assert(!failed(coarsenDFTask(op)) &&
"Failing to sink operations into DFT");
});
// Finally clear up any remaining alloc/dealloc ops that are
// meaningless
SetVector<Operation *> eraseOps;
module->walk([&](memref::AllocOp op) {
// If this memref.alloc's only use left is the
// dealloc, erase both.
if (op->hasOneUse() &&
isa<mlir::memref::DeallocOp>(op->use_begin()->getOwner())) {
eraseOps.insert(op->use_begin()->getOwner());
eraseOps.insert(op);
}
});
for (auto op : eraseOps)
op->erase();
}
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
protected:
bool debug;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug) {
return std::make_unique<FixupDataflowTaskOpsPass>(debug);
}
} // end namespace concretelang
} // end namespace mlir