ROCM IFU: Enable slice layout for insertSliceAsync AMD path

Fix basic_insert_slice_async_1d lit test

Remove code added for debugging

Return hopper test
This commit is contained in:
Ognjen
2023-11-15 00:23:35 +00:00
committed by Jason Furmanek
parent 484852876e
commit 38fbb7e472
4 changed files with 59 additions and 29 deletions

View File

@@ -1468,7 +1468,9 @@ struct InsertSliceOpConversion
"Only support in-place insert_slice for now");
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
"Unexpected srcLayout in InsertSliceOpConversion"));
auto srcShape = srcTy.getShape();
assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");

View File

@@ -512,7 +512,7 @@ public:
ConversionPatternRewriter &rewriter) const {
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
assert((srcShape.size() == 1 || srcShape.size() == 2) &&
"Unexpected rank of storeDistributedToShared");
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcDistributedLayout = srcTy.getEncoding();
@@ -538,8 +538,12 @@ public:
auto wordTy = vec_ty(elemTy, minVec);
Value word;
SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};
SmallVector<Value> offsetVals = {i32_val(0), i32_val(0)};
SmallVector<Value> srcStrides;
SmallVector<Value> offsetVals;
for (int i = 0; i < srcShape.size(); i++) {
srcStrides.push_back(dstStrides[i]);
offsetVals.push_back(i32_val(0));
}
SharedMemoryObject smemObj(smemBase, srcStrides, offsetVals);
DenseMap<unsigned, Value> sharedPtrs =

View File

@@ -935,8 +935,10 @@ private:
auto mask = insertSliceAsyncOp.getMask();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcBlocked =
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto srcLayout = srcTy.getEncoding();
assert((srcLayout.isa<BlockedEncodingAttr, SliceEncodingAttr>() &&
"Unexpected srcLayout"));
auto resSharedLayout =
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
@@ -966,7 +968,7 @@ private:
// load
auto tmpTy =
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
RankedTensorType::get(srcTy.getShape(), resElemTy, srcLayout);
auto loadOp = builder.create<triton::LoadOp>(
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.getSrc(),
insertSliceAsyncOp.getMask(), insertSliceAsyncOp.getOther(),
@@ -999,8 +1001,12 @@ private:
});
mod.walk([&](triton::gpu::AsyncCommitGroupOp asyncCommitGroupOp) -> void {
#ifdef USE_ROCM
asyncCommitGroupOp.erase();
#else
if (!triton::gpu::AsyncCommitGroupOp::isSupported(computeCapability))
asyncCommitGroupOp.erase();
#endif
});
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {