mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): move the lowering of dataflow tasks to RT dialect before bufferization.
This commit is contained in:
@@ -18,7 +18,6 @@
|
||||
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
@@ -61,9 +60,8 @@ static bool isAggregatingBeneficiary(Operation *op) {
|
||||
return isa<FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintIntOp,
|
||||
FHE::AddEintOp, FHE::SubIntEintOp, FHE::SubEintIntOp,
|
||||
FHE::MulEintIntOp, FHE::SubEintOp, FHE::NegEintOp,
|
||||
FHELinalg::FromElementOp, arith::ConstantOp, memref::DimOp,
|
||||
arith::SelectOp, mlir::arith::CmpIOp, memref::GetGlobalOp,
|
||||
memref::CastOp>(op);
|
||||
FHELinalg::FromElementOp, arith::ConstantOp, arith::SelectOp,
|
||||
mlir::arith::CmpIOp>(op);
|
||||
}
|
||||
|
||||
static bool
|
||||
@@ -95,87 +93,6 @@ aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool isFunctionCallName(OpOperand *use, StringRef name) {
|
||||
func::CallOp call = dyn_cast_or_null<mlir::func::CallOp>(use->getOwner());
|
||||
if (!call)
|
||||
return false;
|
||||
SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
if (!sym)
|
||||
return false;
|
||||
func::FuncOp called = dyn_cast_or_null<func::FuncOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(call, sym));
|
||||
if (!called)
|
||||
return false;
|
||||
return called.getName() == name;
|
||||
}
|
||||
|
||||
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
|
||||
for (auto &use : val.getUses()) {
|
||||
aliasedUses.insert(&use);
|
||||
if (dyn_cast<ViewLikeOpInterface>(use.getOwner()))
|
||||
getAliasedUses(use.getOwner()->getResult(0), aliasedUses);
|
||||
}
|
||||
}
|
||||
|
||||
static bool aggregateOutputMemrefAllocations(
|
||||
Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
|
||||
if (beneficiaryOps.count(op))
|
||||
return true;
|
||||
|
||||
if (!isa<mlir::memref::AllocOp>(op))
|
||||
return false;
|
||||
|
||||
Value val = op->getResults().front();
|
||||
DenseSet<OpOperand *> aliasedUses;
|
||||
getAliasedUses(val, aliasedUses);
|
||||
|
||||
// Helper function checking if a memref use writes to memory
|
||||
auto hasMemoryWriteEffect = [&](OpOperand *use) {
|
||||
// Call ops targeting concrete-ffi do not have memory effects
|
||||
// interface, so handle apart.
|
||||
// TODO: this could be handled better in BConcrete or higher.
|
||||
if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") ||
|
||||
isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") ||
|
||||
isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_keyswitch_lwe_u64") ||
|
||||
isFunctionCallName(use, "memref_bootstrap_lwe_u64"))
|
||||
if (use->getOwner()->getOperand(0) == use->get())
|
||||
return true;
|
||||
|
||||
if (isFunctionCallName(use, "memref_copy_one_rank"))
|
||||
if (use->getOwner()->getOperand(1) == use->get())
|
||||
return true;
|
||||
|
||||
// Otherwise we rely on the memory effect interface
|
||||
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
|
||||
if (!effectInterface)
|
||||
return false;
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> effects;
|
||||
effectInterface.getEffects(effects);
|
||||
for (auto eff : effects)
|
||||
if (isa<MemoryEffects::Write>(eff.getEffect()) &&
|
||||
eff.getValue() == use->get())
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// We need to check if this allocated memref is written in this task.
|
||||
// TODO: for now we'll assume that we don't do partial writes or read/writes.
|
||||
for (auto use : aliasedUses)
|
||||
if (hasMemoryWriteEffect(use) &&
|
||||
use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
|
||||
// We will sink the operation, mark its results as now available.
|
||||
beneficiaryOps.insert(op);
|
||||
for (Value result : op->getResults())
|
||||
availableValues.insert(result);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Region &taskOpBody = taskOp.body();
|
||||
|
||||
@@ -191,8 +108,6 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
if (!operandOp)
|
||||
continue;
|
||||
aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues);
|
||||
aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues,
|
||||
taskOp);
|
||||
}
|
||||
|
||||
// Insert operations so that the defs get cloned before uses.
|
||||
@@ -283,45 +198,5 @@ std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
|
||||
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// For documentation see Autopar.td
|
||||
struct FixupDataflowTaskOpsPass
|
||||
: public FixupDataflowTaskOpsBase<FixupDataflowTaskOpsPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
module->walk([](RT::DataflowTaskOp op) {
|
||||
assert(!failed(coarsenDFTask(op)) &&
|
||||
"Failing to sink operations into DFT");
|
||||
});
|
||||
|
||||
// Finally clear up any remaining alloc/dealloc ops that are
|
||||
// meaningless
|
||||
SetVector<Operation *> eraseOps;
|
||||
module->walk([&](memref::AllocOp op) {
|
||||
// If this memref.alloc's only use left is the
|
||||
// dealloc, erase both.
|
||||
if (op->hasOneUse() &&
|
||||
isa<mlir::memref::DeallocOp>(op->use_begin()->getOwner())) {
|
||||
eraseOps.insert(op->use_begin()->getOwner());
|
||||
eraseOps.insert(op);
|
||||
}
|
||||
});
|
||||
for (auto op : eraseOps)
|
||||
op->erase();
|
||||
}
|
||||
|
||||
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug) {
|
||||
return std::make_unique<FixupDataflowTaskOpsPass>(debug);
|
||||
}
|
||||
|
||||
} // end namespace concretelang
|
||||
} // end namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user