[BACKEND] Added support to convert shared to distributed layouts (#1682)

This commit is contained in:
Zahi Moudallal
2023-05-17 17:20:29 -07:00
committed by GitHub
parent 9b072318bb
commit 34817ecc95
4 changed files with 200 additions and 16 deletions

View File

@@ -77,6 +77,10 @@ public:
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
isaDistributedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
@@ -482,9 +486,40 @@ private:
}
}
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);
return success();
}
LogicalResult
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(dstShape.size() == 2 &&
"Unexpected rank of ConvertLayout(shared->blocked)");
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
auto dstLayout = dstTy.getEncoding();
auto inOrd = getOrder(srcSharedLayout);
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto srcStrides =
getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter);
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy);
SmallVector<Value> outVals = loadSharedToDistributed(
dst, dstIndices, src, smemObj, elemTy, loc, rewriter);
Value result =
getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

View File

@@ -359,6 +359,55 @@ public:
return ret;
}
SmallVector<Value>
loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
Value src, SharedMemoryObject smemObj, Type elemTy,
Location loc,
ConversionPatternRewriter &rewriter) const {
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(dstShape.size() == 2 &&
"Unexpected rank of loadSharedToDistributed");
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstDistributedLayout = dstTy.getEncoding();
if (auto mmaLayout = dstDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
assert((!mmaLayout.isVolta()) &&
"ConvertLayout Shared->MMAv1 is not supported yet");
}
auto srcSharedLayout =
srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto srcElemTy = srcTy.getElementType();
auto dstElemTy = dstTy.getElementType();
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
unsigned outVec =
inOrd == outOrd
? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]]
: 1;
unsigned inVec = srcSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy);
assert(outElems == dstIndices.size());
DenseMap<unsigned, Value> sharedPtrs = getSwizzledSharedPtrs(
loc, outVec, dstTy, srcSharedLayout, srcElemTy, smemObj, rewriter,
smemObj.offsets, smemObj.strides);
assert(outElems % minVec == 0 && "Unexpected number of elements");
unsigned numVecs = outElems / minVec;
auto wordTy = vec_ty(elemTy, minVec);
SmallVector<Value> outVals(outElems);
for (unsigned i = 0; i < numVecs; ++i) {
Value smemAddr = sharedPtrs[i * minVec];
smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
Value valVec = load(smemAddr);
for (unsigned v = 0; v < minVec; ++v) {
Value currVal = extract_element(dstElemTy, valVec, i32_val(v));
outVals[i * minVec + v] = currVal;
}
}
return outVals;
}
void storeDistributedToShared(Value src, Value llSrc,
ArrayRef<Value> dstStrides,
ArrayRef<SmallVector<Value>> srcIndices,
@@ -386,16 +435,11 @@ public:
: 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
assert(numElems == srcIndices.size());
auto inVals =
getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
auto wordTy = vec_ty(elemTy, minVec);
auto elemPtrTy = ptr_ty(elemTy);
Value outVecVal = i32_val(outVec);
Value minVecVal = i32_val(minVec);
Value word;
SmallVector<Value> srcStrides = {dstStrides[0], dstStrides[1]};

View File

@@ -1,4 +1,6 @@
import numpy as np
import torch
from numpy.random import RandomState
import triton
import triton.language as tl
@@ -66,3 +68,69 @@ def test_chained_matmul():
block_k=block_k)
assert (torch_result == triton_result).all()
def test_vecmat():
@triton.jit
def batched_vecmat(
# inputs
A, # shape: [dim_m, dim_k]
B, # shape: [dim_m, dim_n, dim_k]
# dimensions
dim_m, dim_n, dim_k,
# outputs
output,
# block information
block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr
):
m_index = tl.program_id(0)
n_index = tl.program_id(1)
# Output tile
output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \
+ (n_index * block_n + tl.arange(0, block_n))[None, :]
vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)
k_blocks = dim_k // block_k
for k_index in range(k_blocks):
# Load A tile
a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \
+ (k_index * block_k + tl.arange(0, block_k))[None, :]
a = tl.load(A + a_tile)
# Load B tile, transposed to [n, m, k] in order to broadcast A on a
# leading dimension.
b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \
+ (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \
+ (k_index * block_k + tl.arange(0, block_k))[None, None, :]
b = tl.load(B + b_tile)
expanded_a, _ = tl.broadcast(a, b)
vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))
tl.store(output + output_tile, vecmat)
M, N, K = 128, 128, 128
block_m, block_n, block_k = 16, 32, 64
rs = RandomState(17)
A_vec = rs.randint(0, 4, (M, K)).astype('float32')
B_vec = rs.randint(0, 4, (M, N, K)).astype('float32')
A = A_vec
B = B_vec
A_tri = torch.tensor(A, device='cuda')
B_tri = torch.tensor(B, device='cuda')
C_tri = torch.zeros((M, N), dtype=torch.float32, device='cuda')
grid = (M // block_m, N // block_n)
batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri,
block_m=block_m, block_n=block_n, block_k=block_k,
num_warps=4, num_stages=1)
A_expanded = A[:, np.newaxis, :]
A_broadcasted = np.broadcast_to(A_expanded, (M, N, K))
AB = A_broadcasted * B
C_ref = np.sum(AB, axis=2)
np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3)

View File

@@ -130,6 +130,17 @@ class BlockedLayout:
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
class SharedLayout:
def __init__(self, vec, per_phase, max_phase, order):
self.vec = str(vec)
self.per_phase = str(per_phase)
self.max_phase = str(max_phase)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@@ -2810,22 +2821,49 @@ layouts = [
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
intermediate_layouts = [
None,
SharedLayout(1, 1, 1, [1, 0]),
SharedLayout(4, 2, 4, [1, 0]),
SharedLayout(2, 2, 4, [1, 0]),
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
layouts = f"""
#src = {src_layout}
#dst = {dst_layout}
""" if interm_layout is None else f"""
#src = {src_layout}
#interm = {interm_layout}
#dst = {dst_layout}
"""
conversion = f"""
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
""" if interm_layout is None else f"""
%15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm>
%16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src>
%17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm>
%18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> tensor<128x128xf16, #src>
%12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
"""
ir = layouts + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
@@ -2840,8 +2878,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
""" + conversion + """
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
tt.return