mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Mask out wrapped threads in store ops (#1283)
This commit is contained in:
@@ -258,7 +258,6 @@ struct StoreOpConversion
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value ptr = op.getPtr();
|
||||
Value mask = op.getMask();
|
||||
Value value = op.getValue();
|
||||
|
||||
Value llPtr = adaptor.getPtr();
|
||||
@@ -273,7 +272,7 @@ struct StoreOpConversion
|
||||
typeConverter->convertType(getElementTypeOrSelf(valueTy));
|
||||
|
||||
unsigned vec = getVectorSize(ptr);
|
||||
unsigned numElems = getElemsPerThread(ptr.getType());
|
||||
unsigned elemsPerThread = getElemsPerThread(ptr.getType());
|
||||
|
||||
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
|
||||
ptr.getType());
|
||||
@@ -284,6 +283,7 @@ struct StoreOpConversion
|
||||
// Determine the vectorization size
|
||||
SmallVector<Value> maskElems;
|
||||
if (llMask) {
|
||||
Value mask = op.getMask();
|
||||
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
|
||||
mask.getType());
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
@@ -292,12 +292,20 @@ struct StoreOpConversion
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
// numElements = 1 for scalar
|
||||
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask,
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
const int numVecs = elemsPerThread / vec;
|
||||
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
@@ -306,7 +314,7 @@ struct StoreOpConversion
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
assert(wordNElems * nWords * numVecs == elemsPerThread);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
// TODO(Superjomn) Deal with cache policy here.
|
||||
@@ -339,7 +347,7 @@ struct StoreOpConversion
|
||||
PTXBuilder ptxBuilder;
|
||||
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
|
||||
|
||||
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
|
||||
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
|
||||
|
||||
auto *asmAddr =
|
||||
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
|
||||
@@ -392,10 +400,10 @@ struct AtomicCASOpConversion
|
||||
auto valElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llVal, rewriter, op.getVal().getType());
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(0));
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
@@ -462,7 +470,6 @@ struct AtomicRMWOpConversion
|
||||
|
||||
Value val = op.getVal();
|
||||
Value ptr = op.getPtr();
|
||||
Value _mask = op.getMask();
|
||||
|
||||
Value llPtr = adaptor.getPtr();
|
||||
Value llVal = adaptor.getVal();
|
||||
@@ -472,29 +479,31 @@ struct AtomicRMWOpConversion
|
||||
loc, llVal, rewriter, val.getType());
|
||||
auto ptrElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llPtr, rewriter, ptr.getType());
|
||||
auto maskElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llMask, rewriter, _mask.getType());
|
||||
SmallVector<Value> maskElements;
|
||||
if (llMask)
|
||||
maskElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llMask, rewriter, op.getMask().getType());
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
// vec = 1 for scalar
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
auto vec = getVectorSize(ptr);
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
int numElems = 1;
|
||||
// tensor
|
||||
if (valueTy) {
|
||||
if (tensorTy) {
|
||||
auto valTy = val.getType().cast<RankedTensorType>();
|
||||
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
|
||||
// mask
|
||||
auto shape = valueTy.getShape();
|
||||
auto numElements = product(shape);
|
||||
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
|
||||
i32_val(numElements)));
|
||||
numElems = tensorTy.getNumElements();
|
||||
}
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask,
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
@@ -507,8 +516,7 @@ struct AtomicRMWOpConversion
|
||||
}
|
||||
|
||||
Value rmwPtr = ptrElements[i];
|
||||
Value rmwMask = maskElements[i];
|
||||
rmwMask = and_(rmwMask, mask);
|
||||
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
|
||||
std::string sTy;
|
||||
PTXBuilder ptxBuilderAtomicRMW;
|
||||
std::string tyId = valueElemNbits * vec == 64
|
||||
@@ -561,7 +569,7 @@ struct AtomicRMWOpConversion
|
||||
return failure();
|
||||
}
|
||||
atom.o(rmwOp).o(sTy);
|
||||
if (valueTy) {
|
||||
if (tensorTy) {
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto retType = vec == 1 ? valueElemTy : vecTy;
|
||||
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
|
||||
@@ -587,8 +595,8 @@ struct AtomicRMWOpConversion
|
||||
rewriter.replaceOp(op, {ret});
|
||||
}
|
||||
}
|
||||
if (valueTy) {
|
||||
Type structTy = getTypeConverter()->convertType(valueTy);
|
||||
if (tensorTy) {
|
||||
Type structTy = getTypeConverter()->convertType(tensorTy);
|
||||
Value resultStruct = getTypeConverter()->packLLElements(
|
||||
loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
|
||||
@@ -798,6 +798,20 @@ def test_store_constant(dtype_str):
|
||||
assert torch.all(output == ref)
|
||||
|
||||
|
||||
def test_load_store_same_ptr():
|
||||
@triton.jit()
|
||||
def kernel(in_out_ptr):
|
||||
pid = tl.program_id(axis=0)
|
||||
x = tl.load(in_out_ptr + pid)
|
||||
out = x * 2
|
||||
tl.store(in_out_ptr + pid, out)
|
||||
|
||||
for _ in range(1000):
|
||||
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
|
||||
kernel[(65536,)](x, num_warps=32)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
|
||||
@@ -1001,14 +1001,60 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32
|
||||
func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: atom.global.gpu.add.f32
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32_scalar
|
||||
func.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
// CHECK: llvm.icmp "eq"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32
|
||||
func.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32_scalar
|
||||
func.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
|
||||
// CHECK: llvm.icmp "slt"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
tt.store %arg0, %arg1 : f32
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user