mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
feat(compiler): support linalg.genric instead of FHELinalg ops in DF parallelization
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h"
|
||||
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
|
||||
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
|
||||
@@ -43,15 +44,19 @@ namespace {
|
||||
|
||||
// TODO: adjust these two functions based on cost model
|
||||
static bool isCandidateForTask(Operation *op) {
|
||||
return isa<
|
||||
FHE::ApplyLookupTableEintOp, FHELinalg::MatMulEintIntOp,
|
||||
FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp,
|
||||
FHELinalg::SubEintIntOp, FHELinalg::SubEintOp, FHELinalg::NegEintOp,
|
||||
FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp,
|
||||
FHELinalg::ApplyMultiLookupTableEintOp,
|
||||
FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot,
|
||||
FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp,
|
||||
FHELinalg::ConcatOp, FHELinalg::Conv2dOp, FHELinalg::TransposeOp>(op);
|
||||
// if it's a linalg.genric operation with encrypted inputs
|
||||
if (auto genericOp = mlir::dyn_cast<mlir::linalg::GenericOp>(op)) {
|
||||
for (auto input : genericOp.getInputs()) {
|
||||
if ((input.getType().isa<ShapedType>() &&
|
||||
mlir::dyn_cast<ShapedType>(input.getType())
|
||||
.getElementType()
|
||||
.isa<FHE::FheIntegerInterface>()) ||
|
||||
input.getType().isa<FHE::FheIntegerInterface>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return isa<FHE::ApplyLookupTableEintOp>(op);
|
||||
}
|
||||
|
||||
/// Identify operations that are beneficial to aggregate into tasks. These
|
||||
@@ -137,8 +142,9 @@ struct BuildDataflowTaskGraphPass
|
||||
|
||||
module.walk([&](mlir::func::FuncOp func) {
|
||||
if (!func->getAttr("_dfr_work_function_attribute"))
|
||||
func.walk(
|
||||
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
|
||||
func.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *childOp) {
|
||||
return this->processOperation(childOp);
|
||||
});
|
||||
|
||||
// Perform simplifications, in particular DCE here in case some
|
||||
// of the operations sunk in tasks are no longer needed in the
|
||||
@@ -154,7 +160,7 @@ struct BuildDataflowTaskGraphPass
|
||||
BuildDataflowTaskGraphPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
void processOperation(mlir::Operation *op) {
|
||||
mlir::WalkResult processOperation(mlir::Operation *op) {
|
||||
if (isCandidateForTask(op)) {
|
||||
IRMapping map;
|
||||
Region &opBody = getOperation().getBody();
|
||||
@@ -187,7 +193,10 @@ protected:
|
||||
opBody);
|
||||
// Once uses are re-targeted to the task, delete the operation
|
||||
op->erase();
|
||||
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
return mlir::WalkResult::advance();
|
||||
}
|
||||
|
||||
bool debug;
|
||||
|
||||
Reference in New Issue
Block a user