diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp index c533d5bf0..d1d57fa31 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp @@ -5,6 +5,7 @@ #include +#include "concretelang/Dialect/FHE/Interfaces/FHEInterfaces.h" #include #include #include @@ -43,15 +44,19 @@ namespace { // TODO: adjust these two functions based on cost model static bool isCandidateForTask(Operation *op) { - return isa< - FHE::ApplyLookupTableEintOp, FHELinalg::MatMulEintIntOp, - FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp, - FHELinalg::SubEintIntOp, FHELinalg::SubEintOp, FHELinalg::NegEintOp, - FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp, - FHELinalg::ApplyMultiLookupTableEintOp, - FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot, - FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp, - FHELinalg::ConcatOp, FHELinalg::Conv2dOp, FHELinalg::TransposeOp>(op); + // if it's a linalg.genric operation with encrypted inputs + if (auto genericOp = mlir::dyn_cast(op)) { + for (auto input : genericOp.getInputs()) { + if ((input.getType().isa() && + mlir::dyn_cast(input.getType()) + .getElementType() + .isa()) || + input.getType().isa()) { + return true; + } + } + } + return isa(op); } /// Identify operations that are beneficial to aggregate into tasks. These @@ -137,8 +142,9 @@ struct BuildDataflowTaskGraphPass module.walk([&](mlir::func::FuncOp func) { if (!func->getAttr("_dfr_work_function_attribute")) - func.walk( - [&](mlir::Operation *childOp) { this->processOperation(childOp); }); + func.walk([&](mlir::Operation *childOp) { + return this->processOperation(childOp); + }); // Perform simplifications, in particular DCE here in case some // of the operations sunk in tasks are no longer needed in the @@ -154,7 +160,7 @@ struct BuildDataflowTaskGraphPass BuildDataflowTaskGraphPass(bool debug) : debug(debug){}; protected: - void processOperation(mlir::Operation *op) { + mlir::WalkResult processOperation(mlir::Operation *op) { if (isCandidateForTask(op)) { IRMapping map; Region &opBody = getOperation().getBody(); @@ -187,7 +193,10 @@ protected: opBody); // Once uses are re-targeted to the task, delete the operation op->erase(); + + return mlir::WalkResult::interrupt(); } + return mlir::WalkResult::advance(); } bool debug;