feat(compiler): support linalg.genric instead of FHELinalg ops in DF parallelization

This commit is contained in:
youben11
2023-09-22 10:37:52 +01:00
committed by Ayoub Benaissa
parent 88dd13756a
commit 64d0741c1b

View File

@@ -5,6 +5,7 @@
#include <iostream>
#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h"
#include <concretelang/Dialect/FHE/IR/FHEDialect.h>
#include <concretelang/Dialect/FHE/IR/FHEOps.h>
#include <concretelang/Dialect/FHE/IR/FHETypes.h>
@@ -43,15 +44,19 @@ namespace {
// TODO: adjust these two functions based on cost model
static bool isCandidateForTask(Operation *op) {
return isa<
FHE::ApplyLookupTableEintOp, FHELinalg::MatMulEintIntOp,
FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp,
FHELinalg::SubEintIntOp, FHELinalg::SubEintOp, FHELinalg::NegEintOp,
FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp,
FHELinalg::ApplyMultiLookupTableEintOp,
FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot,
FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp,
FHELinalg::ConcatOp, FHELinalg::Conv2dOp, FHELinalg::TransposeOp>(op);
// if it's a linalg.genric operation with encrypted inputs
if (auto genericOp = mlir::dyn_cast<mlir::linalg::GenericOp>(op)) {
for (auto input : genericOp.getInputs()) {
if ((input.getType().isa<ShapedType>() &&
mlir::dyn_cast<ShapedType>(input.getType())
.getElementType()
.isa<FHE::FheIntegerInterface>()) ||
input.getType().isa<FHE::FheIntegerInterface>()) {
return true;
}
}
}
return isa<FHE::ApplyLookupTableEintOp>(op);
}
/// Identify operations that are beneficial to aggregate into tasks. These
@@ -137,8 +142,9 @@ struct BuildDataflowTaskGraphPass
module.walk([&](mlir::func::FuncOp func) {
if (!func->getAttr("_dfr_work_function_attribute"))
func.walk(
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
func.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *childOp) {
return this->processOperation(childOp);
});
// Perform simplifications, in particular DCE here in case some
// of the operations sunk in tasks are no longer needed in the
@@ -154,7 +160,7 @@ struct BuildDataflowTaskGraphPass
BuildDataflowTaskGraphPass(bool debug) : debug(debug){};
protected:
void processOperation(mlir::Operation *op) {
mlir::WalkResult processOperation(mlir::Operation *op) {
if (isCandidateForTask(op)) {
IRMapping map;
Region &opBody = getOperation().getBody();
@@ -187,7 +193,10 @@ protected:
opBody);
// Once uses are re-targeted to the task, delete the operation
op->erase();
return mlir::WalkResult::interrupt();
}
return mlir::WalkResult::advance();
}
bool debug;