mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Added support to convert shared to distributed layouts (#1682)
This commit is contained in:
@@ -77,6 +77,10 @@ public:
|
||||
dstLayout.isa<DotOperandEncodingAttr>()) {
|
||||
return lowerMmaToDotOperand(op, adaptor, rewriter);
|
||||
}
|
||||
if (srcLayout.isa<SharedEncodingAttr>() &&
|
||||
isaDistributedLayout(dstLayout)) {
|
||||
return lowerSharedToDistributed(op, adaptor, rewriter);
|
||||
}
|
||||
// TODO: to be implemented
|
||||
llvm_unreachable("unsupported layout conversion");
|
||||
return failure();
|
||||
@@ -482,9 +486,40 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> types(outElems, llvmElemTy);
|
||||
auto *ctx = llvmElemTy.getContext();
|
||||
Type structTy = struct_ty(types);
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
Value src = op.getSrc();
|
||||
Value dst = op.getResult();
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcTy.getShape();
|
||||
auto dstTy = dst.getType().cast<RankedTensorType>();
|
||||
auto dstShape = dstTy.getShape();
|
||||
assert(dstShape.size() == 2 &&
|
||||
"Unexpected rank of ConvertLayout(shared->blocked)");
|
||||
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto inOrd = getOrder(srcSharedLayout);
|
||||
|
||||
auto smemObj =
|
||||
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
|
||||
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
|
||||
|
||||
auto srcStrides =
|
||||
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
|
||||
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy);
|
||||
|
||||
SmallVector<Value> outVals = loadSharedToDistributed(
|
||||
dst, dstIndices, src, smemObj, elemTy, loc, rewriter);
|
||||
|
||||
Value result =
|
||||
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
@@ -359,6 +359,55 @@ public:
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Load a tensor that currently lives in shared memory (described by
// `smemObj`) into per-thread SSA values arranged according to `dst`'s
// distributed (blocked/MMA/slice/...) encoding.
//
// Parameters:
//   dst        - result tensor value; its RankedTensorType carries the
//                distributed destination encoding.
//   dstIndices - per-output-element indices; NOTE(review): in this body they
//                are only used to sanity-check the element count — the actual
//                addresses come from getSwizzledSharedPtrs on smemObj.
//   src        - source tensor value with a SharedEncodingAttr encoding.
//   smemObj    - base pointer / offsets / strides of the shared-memory buffer.
//   elemTy     - LLVM element type used for the vectorized loads.
// Returns: one Value per destination element, in layout order.
SmallVector<Value>
loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
                        Value src, SharedMemoryObject smemObj, Type elemTy,
                        Location loc,
                        ConversionPatternRewriter &rewriter) const {
  auto dstTy = dst.getType().cast<RankedTensorType>();
  auto dstShape = dstTy.getShape();
  // Only rank-2 tensors are handled by this path.
  assert(dstShape.size() == 2 &&
         "Unexpected rank of loadSharedToDistributed");
  auto srcTy = src.getType().cast<RankedTensorType>();
  auto dstDistributedLayout = dstTy.getEncoding();
  if (auto mmaLayout = dstDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
    // MMA v1 (Volta) destinations are explicitly unsupported.
    assert((!mmaLayout.isVolta()) &&
           "ConvertLayout Shared->MMAv1 is not supported yet");
  }
  auto srcSharedLayout =
      srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
  auto srcElemTy = srcTy.getElementType();
  auto dstElemTy = dstTy.getElementType();
  auto inOrd = triton::gpu::getOrder(srcSharedLayout);
  auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
  // Output-side vector width: only when source and destination agree on the
  // dimension order can we read a thread's contiguous elements in one go;
  // otherwise fall back to scalar (1).
  unsigned outVec =
      inOrd == outOrd
          ? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]]
          : 1;
  // Input-side vector width is dictated by the shared layout's `vec`.
  unsigned inVec = srcSharedLayout.getVec();
  // Effective load width is limited by both sides.
  unsigned minVec = std::min(outVec, inVec);
  unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
  assert(outElems == dstIndices.size());

  // Precompute swizzled shared-memory pointers, keyed by element index.
  // NOTE(review): `outVec` (not `minVec`) is passed as the vector hint here —
  // presumably intentional so swizzling matches the destination contiguity;
  // confirm against getSwizzledSharedPtrs' contract.
  DenseMap<unsigned, Value> sharedPtrs = getSwizzledSharedPtrs(
      loc, outVec, dstTy, srcSharedLayout, srcElemTy, smemObj, rewriter,
      smemObj.offsets, smemObj.strides);
  assert(outElems % minVec == 0 && "Unexpected number of elements");
  unsigned numVecs = outElems / minVec;
  auto wordTy = vec_ty(elemTy, minVec);
  SmallVector<Value> outVals(outElems);
  for (unsigned i = 0; i < numVecs; ++i) {
    // One pointer per minVec-wide word; cast to a vector pointer in address
    // space 3 (shared memory on NVIDIA targets — confirm with other uses of
    // ptr_ty(..., 3) in this file) and load the whole word at once.
    Value smemAddr = sharedPtrs[i * minVec];
    smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
    Value valVec = load(smemAddr);
    // Unpack the loaded vector into individual scalar results.
    for (unsigned v = 0; v < minVec; ++v) {
      Value currVal = extract_element(dstElemTy, valVec, i32_val(v));
      outVals[i * minVec + v] = currVal;
    }
  }
  return outVals;
}
|
||||
|
||||
void storeDistributedToShared(Value src, Value llSrc,
|
||||
ArrayRef<Value> dstStrides,
|
||||
ArrayRef<SmallVector<Value>> srcIndices,
|
||||
@@ -386,16 +435,11 @@ public:
|
||||
: 1;
|
||||
unsigned outVec = dstSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
unsigned perPhase = dstSharedLayout.getPerPhase();
|
||||
unsigned maxPhase = dstSharedLayout.getMaxPhase();
|
||||
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
|
||||
assert(numElems == srcIndices.size());
|
||||
auto inVals =
|
||||
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
|
||||
auto wordTy = vec_ty(elemTy, minVec);
|
||||
auto elemPtrTy = ptr_ty(elemTy);
|
||||
Value outVecVal = i32_val(outVec);
|
||||
Value minVecVal = i32_val(minVec);
|
||||
Value word;
|
||||
|
||||
SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.random import RandomState
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -66,3 +68,69 @@ def test_chained_matmul():
|
||||
block_k=block_k)
|
||||
|
||||
assert (torch_result == triton_result).all()
|
||||
|
||||
|
||||
def test_vecmat():
    """Batched vector-matrix product: for each row m, compute
    output[m, n] = sum_k A[m, k] * B[m, n, k], tiled over (block_m, block_n)
    program instances with a block_k reduction loop, and compare the Triton
    result against a NumPy reference.
    """
    @triton.jit
    def batched_vecmat(
        # inputs
        A,  # shape: [dim_m, dim_k]
        B,  # shape: [dim_m, dim_n, dim_k]
        # dimensions
        dim_m, dim_n, dim_k,
        # outputs
        output,
        # block information
        block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
    ):
        m_index = tl.program_id(0)
        n_index = tl.program_id(1)
        # Output tile: row-major offsets into the [dim_m, dim_n] output.
        output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \
            + (n_index * block_n + tl.arange(0, block_n))[None, :]

        vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)
        k_blocks = dim_k // block_k
        for k_index in range(k_blocks):
            # Load A tile ([block_m, block_k] slice of the [dim_m, dim_k] input).
            a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \
                + (k_index * block_k + tl.arange(0, block_k))[None, :]
            a = tl.load(A + a_tile)

            # Load B tile, transposed to [n, m, k] in order to broadcast A on a
            # leading dimension.
            b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \
                + (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \
                + (k_index * block_k + tl.arange(0, block_k))[None, None, :]
            b = tl.load(B + b_tile)

            # Broadcast a ([m, k]) against b ([n, m, k]), reduce over k, then
            # transpose [n, m] -> [m, n] to accumulate into the output tile.
            expanded_a, _ = tl.broadcast(a, b)
            vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))

        tl.store(output + output_tile, vecmat)

    M, N, K = 128, 128, 128
    block_m, block_n, block_k = 16, 32, 64

    # Small deterministic integer values keep float32 accumulation exact
    # enough for the loose tolerance below.
    rs = RandomState(17)
    A_vec = rs.randint(0, 4, (M, K)).astype('float32')
    B_vec = rs.randint(0, 4, (M, N, K)).astype('float32')
    A = A_vec
    B = B_vec

    A_tri = torch.tensor(A, device='cuda')
    B_tri = torch.tensor(B, device='cuda')
    C_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')

    # One program instance per (block_m, block_n) output tile.
    grid = (M // block_m, N // block_n)

    batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
                         block_m=block_m, block_n=block_n, block_k=block_k,
                         num_warps=4, num_stages=1)

    # NumPy reference: broadcast A over the n axis and reduce over k.
    A_expanded = A[:, np.newaxis, :]
    A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
    AB = A_broadcasted * B
    C_ref = np.sum(AB, axis=2)

    np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
@@ -130,6 +130,17 @@ class BlockedLayout:
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
|
||||
|
||||
class SharedLayout:
    """Printable description of a #triton_gpu.shared layout attribute.

    Holds the swizzling parameters (vec, perPhase, maxPhase) and the
    dimension order, and renders them as the textual MLIR attribute via
    __str__ for splicing into IR templates.
    """

    def __init__(self, vec, per_phase, max_phase, order):
        # Stringify each field eagerly so that __str__ is a pure
        # formatting step over already-rendered pieces.
        self.vec = str(vec)
        self.per_phase = str(per_phase)
        self.max_phase = str(max_phase)
        self.order = str(order)

    def __str__(self):
        attrs = "vec={}, perPhase={}, maxPhase={}, order={}".format(
            self.vec, self.per_phase, self.max_phase, self.order)
        return "#triton_gpu.shared<{" + attrs + "}>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
|
||||
def test_empty_kernel(dtype_x, device='cuda'):
|
||||
SIZE = 128
|
||||
@@ -2810,22 +2821,49 @@ layouts = [
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
|
||||
|
||||
intermediate_layouts = [
|
||||
None,
|
||||
SharedLayout(1, 1, 1, [1, 0]),
|
||||
SharedLayout(4, 2, 4, [1, 0]),
|
||||
SharedLayout(2, 2, 4, [1, 0]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device='cuda'):
|
||||
if str(src_layout) == str(dst_layout):
|
||||
pytest.skip()
|
||||
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
||||
pytest.skip()
|
||||
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
layouts = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" if interm_layout is None else f"""
|
||||
#src = {src_layout}
|
||||
#interm = {interm_layout}
|
||||
#dst = {dst_layout}
|
||||
"""
|
||||
|
||||
conversion = f"""
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
""" if interm_layout is None else f"""
|
||||
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
|
||||
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
|
||||
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
|
||||
|
||||
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
"""
|
||||
|
||||
ir = layouts + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
@@ -2840,8 +2878,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
|
||||
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
|
||||
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
|
||||
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
|
||||
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
|
||||
""" + conversion + """
|
||||
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
|
||||
tt.store %14, %13 : tensor<128x128xf16, #dst>
|
||||
tt.return
|
||||
|
||||
Reference in New Issue
Block a user