feat(dfr): add memory management for futures and associated data.

This commit is contained in:
Antoniu Pop
2022-06-08 21:58:33 +01:00
committed by Antoniu Pop
parent b405a2daf2
commit fbca52f4a0
14 changed files with 1047 additions and 348 deletions

View File

@@ -57,7 +57,7 @@ static bool isCandidateForTask(Operation *op) {
/// operations must not have side-effects and not be `isCandidateForTask`
static bool isSinkingBeneficiary(Operation *op) {
return isa<FHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, arith::SelectOp,
mlir::arith::CmpIOp>(op);
mlir::arith::CmpIOp, mlir::memref::GetGlobalOp>(op);
}
static bool
@@ -90,6 +90,92 @@ extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
return true;
}
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
for (auto &use : val.getUses()) {
aliasedUses.insert(&use);
if (isa<memref::CastOp, memref::ViewOp, memref::SubViewOp>(use.getOwner()))
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
}
}
static bool extractOutputMemrefAllocations(
Operation *op, SetVector<Value> existingDependencies,
SetVector<Operation *> &beneficiaryOps,
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
if (beneficiaryOps.count(op))
return true;
if (!isa<mlir::memref::AllocOp>(op))
return false;
Value val = op->getResults().front();
DenseSet<OpOperand *> aliasedUses;
getAliasedUses(val, aliasedUses);
// Helper function checking if a memref use writes to memory
auto hasMemoryWriteEffect = [&](OpOperand *use) {
// Call ops targeting concrete-ffi do not have memory effects
// interface, so handle apart.
// TODO: this could be handled better in BConcrete or higher.
if (isa<mlir::func::CallOp>(use->getOwner())) {
if (getCalledFunction(use->getOwner()).getName() ==
"memref_expand_lut_in_trivial_glwe_ct_u64" ||
getCalledFunction(use->getOwner()).getName() ==
"memref_add_lwe_ciphertexts_u64" ||
getCalledFunction(use->getOwner()).getName() ==
"memref_add_plaintext_lwe_ciphertext_u64" ||
getCalledFunction(use->getOwner()).getName() ==
"memref_mul_cleartext_lwe_ciphertext_u64" ||
getCalledFunction(use->getOwner()).getName() ==
"memref_negate_lwe_ciphertext_u64" ||
getCalledFunction(use->getOwner()).getName() ==
"memref_keyswitch_lwe_u64" ||
getCalledFunction(use->getOwner()).getName() ==
"memref_bootstrap_lwe_u64")
if (use->getOwner()->getOperand(0) == use->get())
return true;
if (getCalledFunction(use->getOwner()).getName() ==
"memref_copy_one_rank")
if (use->getOwner()->getOperand(1) == use->get())
return true;
}
// Otherwise we rely on the memory effect interface
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
if (!effectInterface)
return false;
SmallVector<MemoryEffects::EffectInstance, 2> effects;
effectInterface.getEffects(effects);
for (auto eff : effects)
if (isa<MemoryEffects::Write>(eff.getEffect()) &&
eff.getValue() == use->get())
return true;
return false;
};
// We need to check if this allocated memref is written in this task.
// TODO: for now we'll assume that we don't do partial writes or read/writes.
for (auto use : aliasedUses)
if (hasMemoryWriteEffect(use) &&
use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
// We will sink the operation, mark its results as now available.
beneficiaryOps.insert(op);
for (Value result : op->getResults())
availableValues.insert(result);
return true;
}
return false;
}
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
Region &taskOpBody = taskOp.body();
@@ -104,6 +190,8 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
if (!operandOp)
continue;
extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues);
extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk,
availableValues, taskOp);
}
// Insert operations so that the defs get cloned before uses.