mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Convert layout illegal mem access fix (#2287)
This commit is contained in:
@@ -21,6 +21,7 @@ class AllocationAnalysis;
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec);
|
||||
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getUniqueContigPerThread;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
@@ -50,9 +51,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
|
||||
return {inOrd, outOrd};
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec) {
|
||||
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
@@ -76,15 +75,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
}
|
||||
}
|
||||
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpected layout in getScratchConfigForCvtLayout()");
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
|
||||
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()");
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(srcTy);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstTy);
|
||||
@@ -92,21 +83,44 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
|
||||
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
SmallVector<unsigned> repShape(rank);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] =
|
||||
repShape[d] =
|
||||
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
|
||||
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
|
||||
}
|
||||
if (rank == 1)
|
||||
return paddedRepShape;
|
||||
return repShape;
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec) {
|
||||
auto repShape = getRepShapeForCvtLayout(op);
|
||||
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
unsigned srcContigPerThread =
|
||||
getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
|
||||
unsigned dstContigPerThread =
|
||||
getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that outOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
if (repShape.size() <= 1)
|
||||
return repShape;
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
paddedDim = dstBlockedLayout.getOrder()[0];
|
||||
}
|
||||
paddedRepShape[paddedDim] += pad;
|
||||
return paddedRepShape;
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
repShape[paddedDim] += pad;
|
||||
return repShape;
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
|
||||
@@ -237,12 +237,30 @@ private:
|
||||
llvm_unreachable("unexpected layout in getMultiDimOffset");
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
getWrappedMultiDimOffset(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> multiDimOffset,
|
||||
ArrayRef<unsigned> shape,
|
||||
SmallVector<unsigned> shapePerCTATile,
|
||||
SmallVector<int64_t> shapePerCTA) const {
|
||||
unsigned rank = shape.size();
|
||||
SmallVector<Value> multiDimOffsetWrapped(rank);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
if (shapePerCTATile[d] > shapePerCTA[d])
|
||||
multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d]));
|
||||
else
|
||||
multiDimOffsetWrapped[d] = multiDimOffset[d];
|
||||
}
|
||||
return multiDimOffsetWrapped;
|
||||
}
|
||||
|
||||
// shared memory rd/st for blocked or mma layout with data padding
|
||||
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
|
||||
bool stNotRd, RankedTensorType type,
|
||||
ArrayRef<unsigned> numCTAsEachRep,
|
||||
ArrayRef<unsigned> multiDimRepId, unsigned vec,
|
||||
ArrayRef<unsigned> paddedRepShape,
|
||||
ArrayRef<unsigned> origRepShape,
|
||||
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
|
||||
Value smemBase) const {
|
||||
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
|
||||
@@ -286,8 +304,11 @@ private:
|
||||
SmallVector<Value> multiDimOffset =
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type,
|
||||
multiDimCTAInRepId, shapePerCTATile);
|
||||
Value offset =
|
||||
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
|
||||
SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset(
|
||||
rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile,
|
||||
shapePerCTA);
|
||||
Value offset = linearize(rewriter, loc, multiDimOffsetWrapped,
|
||||
paddedRepShape, outOrd);
|
||||
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
|
||||
Value ptr = gep(elemPtrTy, smemBase, offset);
|
||||
auto vecTy = vec_ty(llvmElemTy, vec);
|
||||
@@ -575,6 +596,7 @@ private:
|
||||
rewriter, srcTy);
|
||||
unsigned inVec = 0;
|
||||
unsigned outVec = 0;
|
||||
auto origRepShape = getRepShapeForCvtLayout(op);
|
||||
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
|
||||
if (getElementTypeOrSelf(op.getType())
|
||||
.isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType>()) {
|
||||
@@ -618,7 +640,7 @@ private:
|
||||
else
|
||||
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy,
|
||||
inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape,
|
||||
outOrd, vals, smemBase);
|
||||
origRepShape, outOrd, vals, smemBase);
|
||||
} else {
|
||||
assert(0 && "ConvertLayout with input layout not implemented");
|
||||
return failure();
|
||||
@@ -651,7 +673,8 @@ private:
|
||||
else
|
||||
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
|
||||
outNumCTAsEachRep, multiDimRepId, outVec,
|
||||
paddedRepShape, outOrd, outVals, smemBase);
|
||||
paddedRepShape, origRepShape, outOrd, outVals,
|
||||
smemBase);
|
||||
} else {
|
||||
assert(0 && "ConvertLayout with output layout not implemented");
|
||||
return failure();
|
||||
|
||||
@@ -339,6 +339,8 @@ public:
|
||||
// Order
|
||||
auto inOrder = triton::gpu::getOrder(srcEncoding);
|
||||
auto outOrder = triton::gpu::getOrder(resSharedLayout);
|
||||
assert(outVec * (maxPhase - 1) <= srcShape[outOrder[0]] &&
|
||||
"Swizzling would generate out of bounds memory accesses");
|
||||
// Tensor indices held by the current thread, as LLVM values
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
|
||||
// Swizzling with leading offsets (e.g. Hopper GMMA)
|
||||
@@ -452,10 +454,10 @@ public:
|
||||
auto dstElemTy = dstTy.getElementType();
|
||||
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
|
||||
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
|
||||
unsigned outVec =
|
||||
inOrd == outOrd
|
||||
? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]]
|
||||
: 1;
|
||||
unsigned outVec = inOrd == outOrd
|
||||
? triton::gpu::getUniqueContigPerThread(
|
||||
dstDistributedLayout, dstShape)[outOrd[0]]
|
||||
: 1;
|
||||
unsigned inVec = srcSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
|
||||
@@ -501,10 +503,10 @@ public:
|
||||
auto dstElemTy = dstTy.getElementType();
|
||||
auto inOrd = triton::gpu::getOrder(srcDistributedLayout);
|
||||
auto outOrd = dstSharedLayout.getOrder();
|
||||
unsigned inVec =
|
||||
inOrd == outOrd
|
||||
? triton::gpu::getContigPerThread(srcDistributedLayout)[inOrd[0]]
|
||||
: 1;
|
||||
unsigned inVec = inOrd == outOrd
|
||||
? triton::gpu::getUniqueContigPerThread(
|
||||
srcDistributedLayout, srcShape)[inOrd[0]]
|
||||
: 1;
|
||||
unsigned outVec = dstSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
|
||||
@@ -3607,6 +3607,7 @@ layouts = [
|
||||
# MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
|
||||
# MmaLayout(1, [4, 1], [1, 1], [0, 1]),
|
||||
# MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]),
|
||||
BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
@@ -3624,15 +3625,16 @@ intermediate_layouts = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device):
|
||||
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
|
||||
if is_hip():
|
||||
pytest.skip("test_convert2d is not supported in HIP")
|
||||
|
||||
if (M == 1 or N == 1) and interm_layout:
|
||||
pytest.skip("Out of bound access when maxPhase > 1")
|
||||
if str(src_layout) == str(dst_layout):
|
||||
pytest.skip()
|
||||
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
||||
@@ -3648,43 +3650,43 @@ def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device):
|
||||
"""
|
||||
|
||||
conversion = f"""
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst>
|
||||
""" if interm_layout is None else f"""
|
||||
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
|
||||
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
|
||||
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
|
||||
%15 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #interm>
|
||||
%16 = triton_gpu.convert_layout %15 : (tensor<{M}x{N}xi32, #interm>) -> tensor<{M}x{N}xi32, #src>
|
||||
%17 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #interm>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<{M}x{N}xf16, #interm>) -> tensor<{M}x{N}xf16, #src>
|
||||
|
||||
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
%12 = triton_gpu.convert_layout %16 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %18 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst>
|
||||
"""
|
||||
|
||||
ir = layouts + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
|
||||
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
|
||||
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
|
||||
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
|
||||
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
|
||||
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
|
||||
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
|
||||
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
|
||||
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
|
||||
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
|
||||
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
|
||||
""" + conversion + """
|
||||
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
||||
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
||||
ir = layouts + f"""
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
|
||||
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
|
||||
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
|
||||
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
|
||||
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
|
||||
%7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
|
||||
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
|
||||
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
|
||||
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
|
||||
""" + conversion + f"""
|
||||
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
|
||||
tt.store %14, %13 : tensor<{M}x{N}xf16, #dst>
|
||||
tt.return
|
||||
}
|
||||
}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
x = to_triton(numpy_random(shape, dtype_str=dtype), device=device)
|
||||
x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
|
||||
z = torch.empty_like(x)
|
||||
|
||||
# write the IR to a temporary file using mkstemp
|
||||
|
||||
Reference in New Issue
Block a user