mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): coarsen task granularity by aggregating lightweight operations.
This commit is contained in:
@@ -24,6 +24,7 @@
|
||||
#include <mlir/IR/BlockAndValueMapping.h>
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/IR/OperationSupport.h>
|
||||
#include <mlir/IR/PatternMatch.h>
|
||||
#include <mlir/Support/LLVM.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
@@ -43,46 +44,49 @@ namespace {
|
||||
// TODO: adjust these two functions based on cost model
|
||||
static bool isCandidateForTask(Operation *op) {
|
||||
return isa<
|
||||
FHE::AddEintIntOp, FHE::AddEintOp, FHE::SubIntEintOp, FHE::MulEintIntOp,
|
||||
FHE::ApplyLookupTableEintOp, FHELinalg::MatMulIntEintOp,
|
||||
FHELinalg::MatMulEintIntOp, FHELinalg::AddEintIntOp, FHELinalg::AddEintOp,
|
||||
FHELinalg::SubIntEintOp, FHELinalg::NegEintOp, FHELinalg::MulEintIntOp,
|
||||
FHELinalg::ApplyLookupTableEintOp, FHELinalg::ApplyMultiLookupTableEintOp,
|
||||
FHE::ApplyLookupTableEintOp, FHELinalg::MatMulEintIntOp,
|
||||
FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp,
|
||||
FHELinalg::SubEintIntOp, FHELinalg::SubEintOp, FHELinalg::NegEintOp,
|
||||
FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp,
|
||||
FHELinalg::ApplyMultiLookupTableEintOp,
|
||||
FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot,
|
||||
FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp,
|
||||
FHELinalg::ConcatOp>(op);
|
||||
FHELinalg::ConcatOp, FHELinalg::Conv2dOp, FHELinalg::TransposeOp>(op);
|
||||
}
|
||||
|
||||
/// Identify operations that are beneficial to sink into tasks. These
|
||||
/// Identify operations that are beneficial to aggregate into tasks. These
|
||||
/// operations must not have side-effects and not be `isCandidateForTask`
|
||||
static bool isSinkingBeneficiary(Operation *op) {
|
||||
return isa<FHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, arith::SelectOp,
|
||||
mlir::arith::CmpIOp, mlir::memref::GetGlobalOp>(op);
|
||||
static bool isAggregatingBeneficiary(Operation *op) {
|
||||
return isa<FHE::ZeroEintOp, FHE::ZeroTensorOp, FHE::AddEintIntOp,
|
||||
FHE::AddEintOp, FHE::SubIntEintOp, FHE::SubEintIntOp,
|
||||
FHE::MulEintIntOp, FHE::SubEintOp, FHE::NegEintOp,
|
||||
FHELinalg::FromElementOp, arith::ConstantOp, memref::DimOp,
|
||||
arith::SelectOp, mlir::arith::CmpIOp, memref::GetGlobalOp,
|
||||
memref::CastOp>(op);
|
||||
}
|
||||
|
||||
static bool
|
||||
extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
|
||||
SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues) {
|
||||
aggregateBeneficiaryOps(Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues) {
|
||||
if (beneficiaryOps.count(op))
|
||||
return true;
|
||||
|
||||
if (!isSinkingBeneficiary(op))
|
||||
if (!isAggregatingBeneficiary(op))
|
||||
return false;
|
||||
|
||||
for (Value operand : op->getOperands()) {
|
||||
// It is already visible in the kernel, keep going.
|
||||
if (availableValues.count(operand))
|
||||
continue;
|
||||
// Else check whether it can be made available via sinking or already is a
|
||||
// dependency.
|
||||
Operation *definingOp = operand.getDefiningOp();
|
||||
if ((!definingOp ||
|
||||
!extractBeneficiaryOps(definingOp, existingDependencies,
|
||||
beneficiaryOps, availableValues)) &&
|
||||
!existingDependencies.count(operand))
|
||||
return false;
|
||||
// Gather the new potential dependences created by sinking this op.
|
||||
llvm::SmallPtrSet<Value, 4> newDependencesIfSunk;
|
||||
for (Value operand : op->getOperands())
|
||||
if (!availableValues.count(operand))
|
||||
newDependencesIfSunk.insert(operand);
|
||||
|
||||
// We further attempt to sink any new dependence
|
||||
for (auto dep : newDependencesIfSunk) {
|
||||
Operation *definingOp = dep.getDefiningOp();
|
||||
if (definingOp)
|
||||
aggregateBeneficiaryOps(definingOp, beneficiaryOps, availableValues);
|
||||
}
|
||||
|
||||
// We will sink the operation, mark its results as now available.
|
||||
beneficiaryOps.insert(op);
|
||||
for (Value result : op->getResults())
|
||||
@@ -90,12 +94,18 @@ extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
|
||||
return true;
|
||||
}
|
||||
|
||||
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
static bool isFunctionCallName(OpOperand *use, StringRef name) {
|
||||
func::CallOp call = dyn_cast_or_null<mlir::func::CallOp>(use->getOwner());
|
||||
if (!call)
|
||||
return false;
|
||||
SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
if (!sym)
|
||||
return nullptr;
|
||||
return dyn_cast_or_null<func::FuncOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
|
||||
return false;
|
||||
func::FuncOp called = dyn_cast_or_null<func::FuncOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(call, sym));
|
||||
if (!called)
|
||||
return false;
|
||||
return called.getName() == name;
|
||||
}
|
||||
|
||||
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
|
||||
@@ -106,9 +116,8 @@ static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool extractOutputMemrefAllocations(
|
||||
Operation *op, SetVector<Value> existingDependencies,
|
||||
SetVector<Operation *> &beneficiaryOps,
|
||||
static bool aggregateOutputMemrefAllocations(
|
||||
Operation *op, SetVector<Operation *> &beneficiaryOps,
|
||||
llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
|
||||
if (beneficiaryOps.count(op))
|
||||
return true;
|
||||
@@ -125,29 +134,19 @@ static bool extractOutputMemrefAllocations(
|
||||
// Call ops targeting concrete-ffi do not have memory effects
|
||||
// interface, so handle apart.
|
||||
// TODO: this could be handled better in BConcrete or higher.
|
||||
if (isa<mlir::func::CallOp>(use->getOwner())) {
|
||||
if (getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_add_lwe_ciphertexts_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_add_plaintext_lwe_ciphertext_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_mul_cleartext_lwe_ciphertext_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_negate_lwe_ciphertext_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_keyswitch_lwe_u64" ||
|
||||
getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_bootstrap_lwe_u64")
|
||||
if (use->getOwner()->getOperand(0) == use->get())
|
||||
return true;
|
||||
if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") ||
|
||||
isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") ||
|
||||
isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") ||
|
||||
isFunctionCallName(use, "memref_keyswitch_lwe_u64") ||
|
||||
isFunctionCallName(use, "memref_bootstrap_lwe_u64"))
|
||||
if (use->getOwner()->getOperand(0) == use->get())
|
||||
return true;
|
||||
|
||||
if (getCalledFunction(use->getOwner()).getName() ==
|
||||
"memref_copy_one_rank")
|
||||
if (use->getOwner()->getOperand(1) == use->get())
|
||||
return true;
|
||||
}
|
||||
if (isFunctionCallName(use, "memref_copy_one_rank"))
|
||||
if (use->getOwner()->getOperand(1) == use->get())
|
||||
return true;
|
||||
|
||||
// Otherwise we rely on the memory effect interface
|
||||
auto effectInterface = dyn_cast<MemoryEffectOpInterface>(use->getOwner());
|
||||
@@ -176,7 +175,7 @@ static bool extractOutputMemrefAllocations(
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Region &taskOpBody = taskOp.body();
|
||||
|
||||
// Identify uses from values defined outside of the scope.
|
||||
@@ -184,14 +183,15 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
getUsedValuesDefinedAbove(taskOpBody, sinkCandidates);
|
||||
|
||||
SetVector<Operation *> toBeSunk;
|
||||
llvm::SmallPtrSet<Value, 4> availableValues;
|
||||
llvm::SmallPtrSet<Value, 4> availableValues(sinkCandidates.begin(),
|
||||
sinkCandidates.end());
|
||||
for (Value operand : sinkCandidates) {
|
||||
Operation *operandOp = operand.getDefiningOp();
|
||||
if (!operandOp)
|
||||
continue;
|
||||
extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues);
|
||||
extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk,
|
||||
availableValues, taskOp);
|
||||
aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues);
|
||||
aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues,
|
||||
taskOp);
|
||||
}
|
||||
|
||||
// Insert operations so that the defs get cloned before uses.
|
||||
@@ -202,16 +202,13 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Operation *clonedOp = builder.clone(*op, map);
|
||||
for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair),
|
||||
taskOp.body());
|
||||
// Once this is sunk, remove all operands of the DFT covered by this
|
||||
for (auto result : op->getResults())
|
||||
for (auto operand : llvm::enumerate(taskOp.getOperands()))
|
||||
if (operand.value() == result) {
|
||||
taskOp->eraseOperand(operand.index());
|
||||
// Once removed, we assume there are no duplicates
|
||||
break;
|
||||
}
|
||||
taskOpBody);
|
||||
}
|
||||
|
||||
SetVector<Value> deps;
|
||||
getUsedValuesDefinedAbove(taskOpBody, deps);
|
||||
taskOp->setOperands(deps.takeVector());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -251,11 +248,14 @@ protected:
|
||||
builder.setInsertionPointAfter(op);
|
||||
auto dftop = builder.create<RT::DataflowTaskOp>(
|
||||
op->getLoc(), op->getResultTypes(), op->getOperands());
|
||||
|
||||
// Add the operation to the task
|
||||
OpBuilder tbbuilder(dftop.body());
|
||||
Operation *clonedOp = tbbuilder.clone(*op, map);
|
||||
// Add sinkable operations to the task
|
||||
assert(!failed(sinkOperationsIntoDFTask(dftop)) &&
|
||||
|
||||
// Coarsen granularity by aggregating all dependence related
|
||||
// lower-weight operations.
|
||||
assert(!failed(coarsenDFTask(dftop)) &&
|
||||
"Failing to sink operations into DFT");
|
||||
|
||||
// Add terminator
|
||||
@@ -282,44 +282,6 @@ std::unique_ptr<mlir::Pass> createBuildDataflowTaskGraphPass(bool debug) {
|
||||
return std::make_unique<BuildDataflowTaskGraphPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Marker to avoid infinite recursion of the rewriting pattern
|
||||
static const mlir::StringLiteral kTransformMarker =
|
||||
"_internal_RT_FixDataflowTaskOpInputsPattern_marker__";
|
||||
|
||||
class FixDataflowTaskOpInputsPattern
|
||||
: public mlir::OpRewritePattern<RT::DataflowTaskOp> {
|
||||
public:
|
||||
FixDataflowTaskOpInputsPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<RT::DataflowTaskOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(RT::DataflowTaskOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
if (op->hasAttr(kTransformMarker))
|
||||
return failure();
|
||||
|
||||
// Identify which values need to be passed as dependences to the
|
||||
// task - this is very conservative and will add constants, index
|
||||
// operations, etc. A simplification will occur later.
|
||||
SetVector<Value> deps;
|
||||
getUsedValuesDefinedAbove(op.body(), deps);
|
||||
auto newop = rewriter.create<RT::DataflowTaskOp>(
|
||||
op.getLoc(), op.getResultTypes(), deps.getArrayRef());
|
||||
rewriter.mergeBlocks(op.getBody(), newop.getBody(),
|
||||
newop.getBody()->getArguments());
|
||||
rewriter.replaceOp(op, {newop.getResults()});
|
||||
|
||||
// Mark this as processed to prevent infinite loop
|
||||
newop.getOperation()->setAttr(kTransformMarker, rewriter.getUnitAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
/// For documentation see Autopar.td
|
||||
struct FixupDataflowTaskOpsPass
|
||||
@@ -329,21 +291,25 @@ struct FixupDataflowTaskOpsPass
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<FixDataflowTaskOpInputsPattern>(context);
|
||||
|
||||
if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))
|
||||
.failed())
|
||||
signalPassFailure();
|
||||
|
||||
// Clear mark and sink any newly created constants or indexing
|
||||
// operations, etc. to reduce the number of input dependences to
|
||||
// the task
|
||||
module->walk([](RT::DataflowTaskOp op) {
|
||||
op.getOperation()->removeAttr(kTransformMarker);
|
||||
assert(!failed(sinkOperationsIntoDFTask(op)) &&
|
||||
assert(!failed(coarsenDFTask(op)) &&
|
||||
"Failing to sink operations into DFT");
|
||||
});
|
||||
|
||||
// Finally clear up any remaining alloc/dealloc ops that are
|
||||
// meaningless
|
||||
SetVector<Operation *> eraseOps;
|
||||
module->walk([&](memref::AllocOp op) {
|
||||
// If this memref.alloc's only use left is the
|
||||
// dealloc, erase both.
|
||||
if (op->hasOneUse() &&
|
||||
isa<mlir::memref::DeallocOp>(op->use_begin()->getOwner())) {
|
||||
eraseOps.insert(op->use_begin()->getOwner());
|
||||
eraseOps.insert(op);
|
||||
}
|
||||
});
|
||||
for (auto op : eraseOps)
|
||||
op->erase();
|
||||
}
|
||||
|
||||
FixupDataflowTaskOpsPass(bool debug) : debug(debug){};
|
||||
|
||||
Reference in New Issue
Block a user