mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Enable slice layout for insertSliceAsync AMD path
Fix basic_insert_slice_async_1d lit test; remove code added for debugging; return hopper test.
This commit is contained in:
@@ -1468,7 +1468,9 @@ struct InsertSliceOpConversion
|
||||
"Only support in-place insert_slice for now");
|
||||
|
||||
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
|
||||
"Unexpected srcLayout in InsertSliceOpConversion"));
|
||||
auto srcShape = srcTy.getShape();
|
||||
assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");
|
||||
|
||||
|
||||
@@ -512,7 +512,7 @@ public:
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
assert(srcShape.size() == 2 &&
|
||||
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
|
||||
"Unexpected rank of storeDistributedToShared");
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcDistributedLayout = srcTy.getEncoding();
|
||||
@@ -538,8 +538,12 @@ public:
|
||||
auto wordTy = vec_ty(elemTy, minVec);
|
||||
Value word;
|
||||
|
||||
SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};
|
||||
SmallVector<Value> offsetVals = {i32_val(0), i32_val(0)};
|
||||
SmallVector<Value> srcStrides;
|
||||
SmallVector<Value> offsetVals;
|
||||
for (int i = 0; i < srcShape.size(); i++) {
|
||||
srcStrides.push_back(dstStrides[i]);
|
||||
offsetVals.push_back(i32_val(0));
|
||||
}
|
||||
SharedMemoryObject smemObj(smemBase, srcStrides, offsetVals);
|
||||
|
||||
DenseMap<unsigned, Value> sharedPtrs =
|
||||
|
||||
@@ -935,8 +935,10 @@ private:
|
||||
auto mask = insertSliceAsyncOp.getMask();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
|
||||
"Unexpected srcLayout"));
|
||||
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
@@ -966,7 +968,7 @@ private:
|
||||
|
||||
// load
|
||||
auto tmpTy =
|
||||
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
|
||||
RankedTensorType::get(srcTy.getShape(), resElemTy, srcLayout);
|
||||
auto loadOp = builder.create<triton::LoadOp>(
|
||||
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
|
||||
insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
|
||||
@@ -999,8 +1001,12 @@ private:
|
||||
});
|
||||
|
||||
mod.walk([&](triton::gpu::AsyncCommitGroupOp asyncCommitGroupOp) -> void {
|
||||
#ifdef USE_ROCM
|
||||
asyncCommitGroupOp.erase();
|
||||
#else
|
||||
if (!triton::gpu::AsyncCommitGroupOp::isSupported(computeCapability))
|
||||
asyncCommitGroupOp.erase();
|
||||
#endif
|
||||
});
|
||||
|
||||
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
|
||||
|
||||
Reference in New Issue
Block a user