mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108
Conflicts: bin/triton-translate.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp python/triton/compiler/compiler.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -1549,10 +1549,12 @@ struct InsertSliceAsyncOpConversion
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto resTy = dst.getType().cast<RankedTensorType>();
|
||||
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
|
||||
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
|
||||
"Unexpected srcLayout in InsertSliceAsyncOpConversion"));
|
||||
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
assert(srcShape.size() == 2 &&
|
||||
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
|
||||
"insert_slice_async: Unexpected rank of %src");
|
||||
|
||||
Value llDst = adaptor.getDst();
|
||||
@@ -1617,25 +1619,15 @@ struct InsertSliceAsyncOpConversion
|
||||
unsigned numElems = getTotalElemsPerThread(srcTy);
|
||||
unsigned perPhase = resSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = resSharedLayout.getMaxPhase();
|
||||
auto sizePerThread = srcBlockedLayout.getSizePerThread();
|
||||
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
|
||||
auto inOrder = srcBlockedLayout.getOrder();
|
||||
DenseMap<unsigned, Value> sharedPtrs =
|
||||
getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
|
||||
smemObj, rewriter, offsetVals, srcStrides);
|
||||
|
||||
// If perPhase * maxPhase > threadsPerCTA, we will have elements
|
||||
// that share the same tile indices. The index calculation will
|
||||
// be cached.
|
||||
auto numSwizzleRows = std::max<unsigned>(
|
||||
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
|
||||
// A sharedLayout encoding has a "vec" parameter.
|
||||
// On the column dimension, if inVec > outVec, it means we have to divide
|
||||
// single vector read into multiple ones
|
||||
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
|
||||
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);
|
||||
|
||||
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
|
||||
// 16 * 8 = 128bits
|
||||
auto maxBitWidth =
|
||||
|
||||
Reference in New Issue
Block a user