mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): Use maps with symbolic offsets in memref casts for wrapper functions
The bufferization of the BConcrete dialect emits calls to Concrete wrapper functions and casts all memrefs to ranked memrefs with dynamic strides and an implicit identity layout map. The implicit identity map does not allow for casts of memrefs with non-zero offsets, e.g., resulting from folding of memrefs related to intermediate results passed as operands to the operation implemented by a wrapper. Casting to memrefs symbolic offsets in the layout map (e.g., `[d0, d1, ...](s0, s1, ...) -> (d0 + s0, d1 + s1, ...)`) allows for more flexibility, in particular this adds support for memrefs with non-zero, constant offsets returned by operations generating intermediate results.
This commit is contained in:
@@ -36,12 +36,31 @@ namespace {} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns a map with a symbolic offset for each dimension, i.e., for N
|
||||
// dimensions, it returns
|
||||
//
|
||||
// [d1, d2, ..., dN](s1, s2, ..., sN) -> (d1 + s1, d2 + s2, ..., dN + sN)
|
||||
//
|
||||
AffineMap getMultiDimSymbolicOffsetMap(mlir::RewriterBase &rewriter,
|
||||
unsigned rank) {
|
||||
SmallVector<AffineExpr, 4> dimExprs;
|
||||
dimExprs.reserve(rank);
|
||||
|
||||
for (unsigned i = 0; i < rank; ++i)
|
||||
dimExprs.push_back(rewriter.getAffineDimExpr(i) +
|
||||
rewriter.getAffineSymbolExpr(i));
|
||||
|
||||
return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, dimExprs,
|
||||
rewriter.getContext());
|
||||
}
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
mlir::MLIRContext *ctx = rewriter.getContext();
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
|
||||
return mlir::MemRefType::get(shape, rewriter.getI64Type(),
|
||||
rewriter.getMultiDimIdentityMap(rank));
|
||||
getMultiDimSymbolicOffsetMap(rewriter, rank));
|
||||
}
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
|
||||
Reference in New Issue
Block a user