mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
The switch to reinstantiating conversion patterns for the FHE-to-TFHE
conversion in commit 73fd6c5fe7
caused all attributes of `scf.for` operations to be dropped during that
conversion. This included the custom attribute `parallel`, which is
used further down the compilation pipeline to generate parallel
code. As a result, the performance of end-to-end benchmarks dropped
significantly.
This patch copies all attributes of `scf.for` operations upon
reinstantiation, which solves the performance regression.
43 lines
1.6 KiB
C++
43 lines
1.6 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include "concretelang/Conversion/Utils/Dialects/SCF.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include <mlir/IR/BlockAndValueMapping.h>
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
template <>
|
|
mlir::LogicalResult
|
|
TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
|
|
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
|
|
mlir::ConversionPatternRewriter &rewriter) const {
|
|
// Create new for loop with empty body, but converted iter args
|
|
scf::ForOp newForOp = rewriter.replaceOpWithNewOp<scf::ForOp>(
|
|
oldOp, adaptor.getLowerBound(), adaptor.getUpperBound(),
|
|
adaptor.getStep(), adaptor.getInitArgs(),
|
|
[&](OpBuilder &builder, Location loc, Value iv, ValueRange args) {});
|
|
|
|
newForOp->setAttrs(adaptor.getAttributes());
|
|
|
|
// Move operations from old for op to new one
|
|
auto &newOperations = newForOp.getBody()->getOperations();
|
|
mlir::Block *oldBody = oldOp.getBody();
|
|
|
|
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
|
|
oldBody->begin(), oldBody->end());
|
|
|
|
// Remap iter args and IV
|
|
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
|
|
newForOp.getBody()->getArguments())) {
|
|
replaceAllUsesInRegionWith(std::get<0>(argsPair), std::get<1>(argsPair),
|
|
newForOp.getRegion());
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
} // namespace concretelang
|
|
} // namespace mlir
|