mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Add reinstantiating rewrite patterns for scf.forall and scf.forall.in_parallel
This commit is contained in:
@@ -23,6 +23,26 @@ TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
|
||||
scf::ForOp oldOp, mlir::OpConversionPattern<scf::ForOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
//
|
||||
// Specializations for ForallOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::ForallOp, false>::matchAndRewrite(
|
||||
scf::ForallOp oldOp,
|
||||
mlir::OpConversionPattern<scf::ForallOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
//
|
||||
// Specializations for InParallelOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::InParallelOp, false>::matchAndRewrite(
|
||||
scf::InParallelOp oldOp,
|
||||
mlir::OpConversionPattern<scf::InParallelOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -39,5 +39,92 @@ TypeConvertingReinstantiationPattern<scf::ForOp, false>::matchAndRewrite(
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//
|
||||
// Specializations for ForallOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::ForallOp, false>::matchAndRewrite(
|
||||
scf::ForallOp oldOp,
|
||||
mlir::OpConversionPattern<scf::ForallOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
// Create new forall operation with empty body, but converted iter
|
||||
// args
|
||||
llvm::SmallVector<mlir::OpFoldResult> lbs = getMixedValues(
|
||||
adaptor.getStaticLowerBound(), adaptor.getDynamicLowerBound(), rewriter);
|
||||
llvm::SmallVector<mlir::OpFoldResult> ubs = getMixedValues(
|
||||
adaptor.getStaticUpperBound(), adaptor.getDynamicUpperBound(), rewriter);
|
||||
llvm::SmallVector<mlir::OpFoldResult> step = getMixedValues(
|
||||
adaptor.getStaticStep(), adaptor.getDynamicStep(), rewriter);
|
||||
|
||||
rewriter.setInsertionPoint(oldOp);
|
||||
|
||||
scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
|
||||
oldOp.getLoc(), lbs, ubs, step, adaptor.getOutputs(),
|
||||
adaptor.getMapping());
|
||||
|
||||
newForallOp->setAttrs(adaptor.getAttributes());
|
||||
|
||||
// Move operations from old for op to new one
|
||||
auto &newOperations = newForallOp.getBody()->getOperations();
|
||||
mlir::Block *oldBody = oldOp.getBody();
|
||||
|
||||
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
|
||||
oldBody->begin(), std::prev(oldBody->end()));
|
||||
|
||||
// Move operations from `scf.forall.in_parallel` terminator of the
|
||||
// old op to the terminator of the new op
|
||||
|
||||
mlir::scf::InParallelOp oldTerminator =
|
||||
llvm::dyn_cast<mlir::scf::InParallelOp>(*std::prev(oldBody->end()));
|
||||
|
||||
assert(oldTerminator && "Last operation of `scf.forall` op expected be a "
|
||||
"`scf.forall.in_parallel` op");
|
||||
|
||||
mlir::scf::InParallelOp newTerminator = newForallOp.getTerminator();
|
||||
|
||||
mlir::Block::OpListType &oldTerminatorOps =
|
||||
oldTerminator.getRegion().getBlocks().begin()->getOperations();
|
||||
mlir::Block::OpListType &newTerminatorOps =
|
||||
newTerminator.getRegion().getBlocks().begin()->getOperations();
|
||||
|
||||
newTerminatorOps.splice(newTerminatorOps.begin(), oldTerminatorOps,
|
||||
oldTerminatorOps.begin(), oldTerminatorOps.end());
|
||||
|
||||
// Remap iter args and IV
|
||||
for (auto argsPair : llvm::zip(oldOp.getBody()->getArguments(),
|
||||
newForallOp.getBody()->getArguments())) {
|
||||
std::get<0>(argsPair).replaceAllUsesWith(std::get<1>(argsPair));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(oldOp, newForallOp.getResults());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//
|
||||
// Specializations for InParallelOp
|
||||
//
|
||||
template <>
|
||||
mlir::LogicalResult
|
||||
TypeConvertingReinstantiationPattern<scf::InParallelOp, false>::matchAndRewrite(
|
||||
scf::InParallelOp oldOp,
|
||||
mlir::OpConversionPattern<scf::InParallelOp>::OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
// Create new for loop with empty body, but converted iter args
|
||||
scf::InParallelOp newInParallelOp =
|
||||
rewriter.replaceOpWithNewOp<scf::InParallelOp>(oldOp);
|
||||
|
||||
// Move operations from old for op to new one
|
||||
auto &newOperations = newInParallelOp.getBody()->getOperations();
|
||||
mlir::Block *oldBody = oldOp.getBody();
|
||||
|
||||
newOperations.splice(newOperations.begin(), oldBody->getOperations(),
|
||||
oldBody->begin(), oldBody->end());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user