Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00.

Commit: [BACKEND] Modified store op thread masking (#1605)
This commit is contained in:
@@ -298,14 +298,7 @@ struct StoreOpConversion
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
// numElements = 1 for scalar
|
||||
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask,
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNBits = dtsize * 8;
|
||||
|
||||
@@ -421,6 +421,46 @@ public:
|
||||
// -----------------------------------------------------------------------
|
||||
// Utilities
|
||||
// -----------------------------------------------------------------------
|
||||
Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
|
||||
Location loc) const {
|
||||
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
if (tensorTy) {
|
||||
auto layout = tensorTy.getEncoding();
|
||||
auto shape = tensorTy.getShape();
|
||||
unsigned rank = shape.size();
|
||||
auto sizePerThread = triton::gpu::getSizePerThread(layout);
|
||||
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
|
||||
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(tid, warpSize);
|
||||
Value warpId = udiv(tid, warpSize);
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
SmallVector<Value> multiDimThreadId =
|
||||
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
|
||||
for (unsigned dim = 0; dim < rank; ++dim) {
|
||||
// if there is no data replication across threads on this dimension
|
||||
if (shape[dim] >= shapePerCTA[dim])
|
||||
continue;
|
||||
// Otherwise, we need to mask threads that will replicate data on this
|
||||
// dimension. Calculate the thread index on this dimension for the CTA
|
||||
Value threadDim =
|
||||
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
|
||||
multiDimThreadId[dim]);
|
||||
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
|
||||
i32_val(shape[dim])));
|
||||
}
|
||||
} else {
|
||||
// If the tensor is not ranked, then it is a scalar and only thread 0 can
|
||||
// write
|
||||
mask = and_(mask, icmp_slt(tid, i32_val(1)));
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
|
||||
@@ -1488,6 +1488,53 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
# Source layouts exercised by test_store_op below: two blocked layouts with
# different warp arrangements plus one MMA layout.
# Constructor args are presumably
# (size_per_thread, threads_per_warp, warps_per_cta, order) — confirm
# against the BlockedLayout helper class defined earlier in this file.
layouts = [
    BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
    BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
    MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [32, 64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
def test_store_op(M, src_layout, device='cuda'):
    """Compile a hand-written TTGIR kernel and check tt.store masking.

    The kernel loads an M-element slice, expands it to Mx1 in ``src_layout``,
    and stores it out; each output row must be written exactly once even when
    the layout replicates data across threads.
    """
    ir = f"""
    #src = {src_layout}
    module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
        tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
            %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
            %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
            %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
            %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
            %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
            %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
            %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
            %7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
            %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
            tt.store %8, %4 : tensor<{M}x1xf32, #src>
            tt.return
        }}
    }}
    """

    # triton.compile takes a path, so round-trip the IR through a temp file.
    import tempfile
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
        f.write(ir)
        f.flush()
        store_kernel = triton.compile(f.name)

    rng = RandomState(17)
    x = rng.randint(0, 4, (M, 1)).astype('float32')
    x_tri = torch.tensor(x, device=device)
    # Destination starts zeroed so an unmasked duplicate store is visible.
    y_tri = torch.tensor(np.zeros((M, 1), dtype='float32'), device=device)

    store_kernel[(1, 1, 1)](x_tri, y_tri)
    np.testing.assert_allclose(x, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
|
||||
delta = mean_2 - mean_1
|
||||
|
||||
@@ -1038,7 +1038,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32
|
||||
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
|
||||
Reference in New Issue
Block a user