mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 08:31:31 -05:00
The current scheme used by reinstantiating conversion patterns in `lib/Conversion/Utils/Dialects` for operations with blocks is to create a new operation with empty blocks, to move the operations from the old blocks and then to replace any references to block arguments. However, such in-place updates of the types of block arguments leave conversion patterns for operations nested in the blocks without the ability to determine the original types of values from before the update. This change uses proper signature conversion for block arguments, such that the original types of block arguments with converted types are preserved, while the new types are made available through the dialect conversion infrastructure via the respective adaptors.
89 lines
3.5 KiB
C++
89 lines
3.5 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include "concretelang/Conversion/Utils/Utils.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
// Builds an i64 memref type of the given `rank` whose sizes are all dynamic
// and whose layout is the affine map
//
//   (d0, ..., d{rank-1})[s0, s1, ..., s{rank}]
//       -> (s0 + d0 * s1 + d1 * s2 + ... + d{rank-1} * s{rank})
//
// i.e. symbol s0 plays the role of a dynamic offset and symbol s{i+1} the
// role of the dynamic stride of dimension i.
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
                                             size_t rank) {
  // Start from the offset symbol and append one `dim * stride` term per
  // dimension, in dimension order.
  mlir::AffineExpr layoutExpr = rewriter.getAffineSymbolExpr(0);
  for (size_t dim = 0; dim != rank; ++dim) {
    mlir::AffineExpr dimExpr = rewriter.getAffineDimExpr(dim);
    mlir::AffineExpr strideExpr = rewriter.getAffineSymbolExpr(dim + 1);
    layoutExpr = layoutExpr + (dimExpr * strideExpr);
  }

  // `rank` dimensions; `rank + 1` symbols (one offset plus one stride each).
  mlir::AffineMap layout =
      mlir::AffineMap::get(rank, rank + 1, layoutExpr, rewriter.getContext());

  std::vector<int64_t> dynamicShape(rank, mlir::ShapedType::kDynamic);
  return mlir::MemRefType::get(dynamicShape, rewriter.getI64Type(), layout);
}
|
|
|
|
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
|
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
|
|
mlir::Type valueType = value.getType();
|
|
|
|
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
|
return rewriter.create<mlir::memref::CastOp>(
|
|
value.getLoc(),
|
|
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
|
value);
|
|
} else {
|
|
return value;
|
|
}
|
|
}
|
|
|
|
// Materializes the integer elements of `arrAttr` as i64 data in a global
// memref and returns a `memref.get_global` of it, passed through
// `getCastedMemRef`.
//
// A temporary `arith.constant` holding the data is created only so that
// `bufferization::getGlobalFor` can provide the matching `memref.global`; the
// constant is erased right after the lookup. Each element is read with
// `getZExtValue()`, i.e. as an unsigned (zero-extended) integer.
mlir::Value globalMemrefFromArrayAttr(mlir::RewriterBase &rewriter,
                                      mlir::Location loc,
                                      mlir::ArrayAttr arrAttr) {
  std::vector<int64_t> elements;
  elements.reserve(arrAttr.size());
  for (auto elementAttr : arrAttr)
    elements.push_back(
        elementAttr.cast<mlir::IntegerAttr>().getValue().getZExtValue());

  // 1-D tensor type with one i64 entry per array element.
  mlir::Type tensorTy = mlir::RankedTensorType::get(
      {static_cast<int>(arrAttr.size())}, rewriter.getI64Type());

  auto tempConstant = rewriter.create<mlir::arith::ConstantOp>(
      loc, rewriter.getI64TensorAttr(elements), tensorTy);
  auto globalMemref = mlir::bufferization::getGlobalFor(tempConstant, 0);
  rewriter.eraseOp(tempConstant);
  assert(!mlir::failed(globalMemref));

  auto getGlobal = rewriter.create<mlir::memref::GetGlobalOp>(
      loc, globalMemref->getType(), globalMemref->getName());
  return mlir::concretelang::getCastedMemRef(rewriter, getGlobal);
}
|
|
|
|
// Converts an operation `op` with nested blocks using a type
|
|
// converter and a conversion pattern rewriter, such that the newly
|
|
// created operation uses the operands specified in `newOperands` and
|
|
// returns a value of the types `newResultTypes`.
|
|
mlir::Operation *
|
|
convertOpWithBlocks(mlir::Operation *op, mlir::ValueRange newOperands,
|
|
mlir::TypeRange newResultTypes,
|
|
mlir::TypeConverter &typeConverter,
|
|
mlir::ConversionPatternRewriter &rewriter) {
|
|
mlir::OperationState state(op->getLoc(), op->getName().getStringRef(),
|
|
newOperands, newResultTypes, op->getAttrs(),
|
|
op->getSuccessors());
|
|
|
|
for (Region ®ion : op->getRegions()) {
|
|
Region *newRegion = state.addRegion();
|
|
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
|
|
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
|
|
(void)typeConverter.convertSignatureArgs(newRegion->getArgumentTypes(),
|
|
result);
|
|
rewriter.applySignatureConversion(newRegion, result);
|
|
}
|
|
|
|
Operation *newOp = rewriter.create(state);
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
|
|
return newOp;
|
|
}
|
|
} // namespace concretelang
|
|
} // namespace mlir
|