[BACKEND] Modified store op thread masking (#1605)

This commit is contained in:
Zahi Moudallal
2023-05-04 17:15:05 -07:00
committed by GitHub
parent deb2c71fb4
commit e2ae2c6c48
4 changed files with 88 additions and 9 deletions

View File

@@ -298,14 +298,7 @@ struct StoreOpConversion
vec = std::min(vec, maskAlign);
}
// numElements = 1 for scalar
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
Value mask = getMask(valueTy, rewriter, loc);
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;

View File

@@ -421,6 +421,46 @@ public:
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc) const {
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
Value mask = int_val(1, 1);
auto tid = tid_val();
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTA[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 can
// write
mask = and_(mask, icmp_slt(tid, i32_val(1)));
}
return mask;
}
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.

View File

@@ -1488,6 +1488,53 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]
@pytest.mark.parametrize("M", [32, 64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
def test_store_op(M, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
tt.store %8, %4 : tensor<{M}x1xf32, #src>
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
store_kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, 1)).astype('float32')
y = np.zeros((M, 1), dtype='float32')
x_tri = torch.tensor(x, device=device)
y_tri = torch.tensor(y, device=device)
pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
y_ref = x
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
delta = mean_2 - mean_1

View File

@@ -1038,7 +1038,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
// CHECK: llvm.inline_asm