mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(dfr): add memory management for futures and associated data.
This commit is contained in:
@@ -57,7 +57,7 @@ static bool isCandidateForTask(Operation *op) {
|
||||
/// operations must not have side-effects and not be `isCandidateForTask`
|
||||
static bool isSinkingBeneficiary(Operation *op) {
|
||||
return isa<FHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, arith::SelectOp,
|
||||
mlir::arith::CmpIOp>(op);
|
||||
mlir::arith::CmpIOp, mlir::memref::GetGlobalOp>(op);
|
||||
}
|
||||
|
||||
static bool
|
||||
@@ -90,6 +90,92 @@ extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
|
||||
return true;
|
||||
}
|
||||
|
||||
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
if (!sym)
|
||||
return nullptr;
|
||||
return dyn_cast_or_null<func::FuncOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
|
||||
}
|
||||
|
||||
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
|
||||
for (auto &use : val.getUses()) {
|
||||
aliasedUses.insert(&use);
|
||||
if (isa<memref::CastOp, memref::ViewOp, memref::SubViewOp>(use.getOwner()))
|
||||
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
|
||||
}
|
||||
}
|
||||
|
||||
static bool extractOutputMemrefAllocations(
|
||||
Operation *op, SetVector<Value> existingDependencies,
|
||||
SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
|
||||
if (beneficiaryOps.count(op))
|
||||
return true;
|
||||
|
||||
if (!isa<mlir::memref::AllocOp>(op))
|
||||
return false;
|
||||
|
||||
Value val = op->getResults().front();
|
||||
DenseSet<OpOperand *> aliasedUses;
|
||||
getAliasedUses(val, aliasedUses);
|
||||
|
||||
// Helper function checking if a memref use writes to memory
|
||||
auto hasMemoryWriteEffect = [&](OpOperand *use) {
|
||||
// Call ops targeting concrete-ffi do not have memory effects
|
||||
// interface, so handle apart.
|
||||
// TODO: this could be handled better in BConcrete or higher.
|
||||
if (isa<mlir::func::CallOp>(use->getOwner())) {
|
||||
if (getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_add_lwe_ciphertexts_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_add_plaintext_lwe_ciphertext_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_mul_cleartext_lwe_ciphertext_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_negate_lwe_ciphertext_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_keyswitch_lwe_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_bootstrap_lwe_u64")
|
||||
if (use->getOwner()->getOperand(0) == use->get())
|
||||
return true;
|
||||
|
||||
if (getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_copy_one_rank")
|
||||
if (use->getOwner()->getOperand(1) == use->get())
|
||||
return true;
|
||||
}
|
||||
|
||||
// Otherwise we rely on the memory effect interface
|
||||
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
|
||||
if (!effectInterface)
|
||||
return false;
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> effects;
|
||||
effectInterface.getEffects(effects);
|
||||
for (auto eff : effects)
|
||||
if (isa<MemoryEffects::Write>(eff.getEffect()) &&
|
||||
eff.getValue() == use->get())
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// We need to check if this allocated memref is written in this task.
|
||||
// TODO: for now we'll assume that we don't do partial writes or read/writes.
|
||||
for (auto use : aliasedUses)
|
||||
if (hasMemoryWriteEffect(use) &&
|
||||
use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
|
||||
// We will sink the operation, mark its results as now available.
|
||||
beneficiaryOps.insert(op);
|
||||
for (Value result : op->getResults())
|
||||
availableValues.insert(result);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Region &taskOpBody = taskOp.body();
|
||||
|
||||
@@ -104,6 +190,8 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
if (!operandOp)
|
||||
continue;
|
||||
extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues);
|
||||
extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk,
|
||||
availableValues, taskOp);
|
||||
}
|
||||
|
||||
// Insert operations so that the defs get cloned before uses.
|
||||
|
||||
Reference in New Issue
Block a user