mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] A general interface for initializing destination operands in load/store operations (#1427)
This commit is contained in:
@@ -144,7 +144,12 @@ struct PTXBuilder {
|
||||
|
||||
// Create a new operand which is written to, that is, the constraint starts
|
||||
// with "=", e.g. "=r".
|
||||
Operand *newOperand(StringRef constraint);
|
||||
// If the operand will be used in predicated execution,
|
||||
// users may want to initialize it before use.
|
||||
// Otherwise if the register is only used in the true branch or the false
|
||||
// branch but not both, the register is undefined and ptxas can perform
|
||||
// aggressive optimizations that may lead to incorrect results.
|
||||
Operand *newOperand(StringRef constraint, bool init = false);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int64_t v);
|
||||
@@ -171,6 +176,8 @@ private:
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
void initOperand(Operand *opr);
|
||||
|
||||
// Make the operands in argArchive follow the provided \param order.
|
||||
void reorderArgArchive(ArrayRef<Operand *> order) {
|
||||
assert(order.size() == argArchive.size());
|
||||
|
||||
@@ -108,7 +108,7 @@ struct LoadOpConversion
|
||||
}
|
||||
|
||||
// vectorized iteration through all the pointer/mask/other elements
|
||||
const int valueElemNbits =
|
||||
const int valueElemNBits =
|
||||
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
|
||||
const int numVecs = numElems / vec;
|
||||
|
||||
@@ -117,11 +117,11 @@ struct LoadOpConversion
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
|
||||
const size_t totalWidth = valueElemNBits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
const size_t wordNElems = width / valueElemNBits;
|
||||
const size_t movWidth = width < 16 ? 16 : width;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
@@ -138,18 +138,12 @@ struct LoadOpConversion
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
PTXInstr &init =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
PTXInstr::Operand *zero = ptxBuilder.newConstantOperand(0);
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint,
|
||||
/*init=*/true); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
// Initialize the destination register, otherwise the register will
|
||||
// be undefined if the predicate is false.
|
||||
init(opr, zero);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
@@ -186,7 +180,7 @@ struct LoadOpConversion
|
||||
PTXInstr &mov =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
size_t size = width / valueElemNBits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = undef(vecTy);
|
||||
@@ -201,8 +195,8 @@ struct LoadOpConversion
|
||||
PTXInstr::Operand *opr{};
|
||||
|
||||
if (otherIsSplatConstInt) {
|
||||
for (size_t s = 0; s < 32; s += valueElemNbits)
|
||||
splatVal |= splatVal << valueElemNbits;
|
||||
for (size_t s = 0; s < 32; s += valueElemNBits)
|
||||
splatVal |= splatVal << valueElemNBits;
|
||||
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||
} else
|
||||
opr = ptxBuilder.newOperand(v, readConstraint);
|
||||
@@ -233,10 +227,10 @@ struct LoadOpConversion
|
||||
curr = ret;
|
||||
}
|
||||
curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
|
||||
width / valueElemNbits));
|
||||
width / valueElemNBits));
|
||||
rets.push_back(curr);
|
||||
}
|
||||
int tmp = width / valueElemNbits;
|
||||
int tmp = width / valueElemNBits;
|
||||
for (size_t ii = 0; ii < vec; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
|
||||
@@ -312,18 +306,18 @@ struct StoreOpConversion
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
const size_t valueElemNBits = dtsize * 8;
|
||||
|
||||
const int numVecs = elemsPerThread / vec;
|
||||
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
|
||||
const size_t totalWidth = valueElemNBits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
const size_t wordNElems = width / valueElemNBits;
|
||||
assert(wordNElems * nWords * numVecs == elemsPerThread);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
@@ -414,6 +408,7 @@ struct AtomicCASOpConversion
|
||||
Type valueElemTy =
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(0));
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
@@ -424,13 +419,12 @@ struct AtomicCASOpConversion
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
|
||||
Value casPtr = ptrElements[0];
|
||||
Value casCmp = cmpElements[0];
|
||||
Value casVal = valElements[0];
|
||||
|
||||
PTXBuilder ptxBuilderAtomicCAS;
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
@@ -441,7 +435,7 @@ struct AtomicCASOpConversion
|
||||
barrier();
|
||||
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
@@ -498,7 +492,7 @@ struct AtomicRMWOpConversion
|
||||
Type valueElemTy =
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
|
||||
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
auto vec = getVectorSize(ptr);
|
||||
@@ -529,16 +523,16 @@ struct AtomicRMWOpConversion
|
||||
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
|
||||
std::string sTy;
|
||||
PTXBuilder ptxBuilderAtomicRMW;
|
||||
std::string tyId = valueElemNbits * vec == 64
|
||||
std::string tyId = valueElemNBits * vec == 64
|
||||
? "l"
|
||||
: (valueElemNbits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
|
||||
: (valueElemNBits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
|
||||
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
|
||||
|
||||
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
auto sBits = std::to_string(valueElemNbits);
|
||||
auto sBits = std::to_string(valueElemNBits);
|
||||
switch (atomicRmwAttr) {
|
||||
case RMWOp::AND:
|
||||
sTy = "b" + sBits;
|
||||
@@ -554,9 +548,9 @@ struct AtomicRMWOpConversion
|
||||
break;
|
||||
case RMWOp::FADD:
|
||||
rmwOp = "add";
|
||||
rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
|
||||
rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
|
||||
sTy = "f" + sBits;
|
||||
sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
|
||||
sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
|
||||
break;
|
||||
case RMWOp::MAX:
|
||||
sTy = "s" + sBits;
|
||||
@@ -598,7 +592,14 @@ struct AtomicRMWOpConversion
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
store(old, atomPtr);
|
||||
// Only threads with rmwMask = True store the result
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto &storeShared =
|
||||
ptxBuilderStore.create<>("st")->shared().o("b" + sBits);
|
||||
auto *ptrOpr = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOpr = ptxBuilderStore.newOperand(old, tyId);
|
||||
storeShared(ptrOpr, valOpr).predicate(rmwMask);
|
||||
ptxBuilderStore.launch(rewriter, loc, void_ty(ctx));
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
|
||||
@@ -19,12 +19,34 @@ PTXBuilder::newOperand(mlir::Value value, StringRef constraint,
|
||||
return opr;
|
||||
}
|
||||
|
||||
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint) {
|
||||
// Zero-initializes the destination operand \p opr by emitting a
// `mov.u<numBits>` instruction whose source is the constant operand 0.
//
// The bit width is chosen from the second character of the operand's
// constraint string: 'c' or 'h' -> 16, 'r' -> 32, 'l' -> 64. Any other
// character reaches llvm_unreachable.
//
// \param opr operand to initialize; its constraint is indexed at [1] without
//            a length check here, so it is expected to already hold a
//            two-character constraint such as "=r".
void PTXBuilder::initOperand(Operand *opr) {
|
||||
auto numBits = 0;
|
||||
// Derive numBits from the constraint.
|
||||
if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h')
|
||||
numBits = 16;
|
||||
else if (opr->constraint[1] == 'r')
|
||||
numBits = 32;
|
||||
else if (opr->constraint[1] == 'l')
|
||||
numBits = 64;
|
||||
else
|
||||
llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str());
|
||||
// If numBits is less than 16, we use 16 as default because PTX does not
|
||||
// support 8-bit mov.
// NOTE(review): every branch above assigns at least 16, so this clamp does
// not change numBits as the code is currently written.
|
||||
numBits = numBits < 16 ? 16 : numBits;
|
||||
// Build the constant 0 operand and the mov instruction, then apply the
// instruction to (destination, source).
auto *zero = newConstantOperand(0);
|
||||
auto &init = create<>("mov")->o("u" + std::to_string(numBits));
|
||||
init(opr, zero);
|
||||
}
|
||||
|
||||
// Creates a new written-to operand carrying the given output constraint.
//
// \param constraint output constraint string; asserted below to begin with
//                   '=' (e.g. "=r").
// \param init       when true, initOperand is called on the new operand
//                   before it is returned.
// \return the newly created operand, with its idx taken from oprCounter
//         (which is post-incremented) and its constraint set.
PTXBuilder::Operand *PTXBuilder::newOperand(StringRef constraint, bool init) {
|
||||
// Constraint should be something like "=r"
|
||||
// NOTE(review): the two asserts below look like the removed and the added
// version of the same line in this diff rendering; the second is the
// stricter one (exactly two characters) -- confirm against the real file.
assert(!constraint.empty() && constraint[0] == '=');
|
||||
assert(constraint.size() == 2 && constraint[0] == '=');
|
||||
auto *opr = newOperand();
|
||||
opr->idx = oprCounter++;
|
||||
opr->constraint = constraint;
|
||||
// Optional initialization of the destination, requested by the caller.
if (init) {
|
||||
initOperand(opr);
|
||||
}
|
||||
return opr;
|
||||
}
|
||||
|
||||
|
||||
@@ -739,6 +739,17 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
|
||||
|
||||
# Checks tl.atomic_max when it is issued only under a condition.
#
# The kernel is launched over a 1-D grid of 4096 programs; each program reads
# its program_id and calls tl.atomic_max(X, val) only when val < 64. X starts
# as a single int32 zero, so the expected final value is 63.
# NOTE(review): presumably a regression test for uninitialized destination
# registers in predicated atomics -- confirm against the commit description.
def test_atomic_rmw_predicate(device="cuda"):
|
||||
@triton.jit
|
||||
def kernel(X):
|
||||
val = tl.program_id(0)
|
||||
if val < 64:
|
||||
tl.atomic_max(X, val)
|
||||
x = torch.zeros((1,), device=device, dtype=torch.int32)
|
||||
kernel[(4096,)](x)
|
||||
assert x.item() == 63
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape, axis",
|
||||
[(shape, axis) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32)] for axis in [0, 1]])
|
||||
def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
|
||||
Reference in New Issue
Block a user