mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] A general interface for initializing destination operands in load/store operations (#1427)
This commit is contained in:
@@ -108,7 +108,7 @@ struct LoadOpConversion
|
||||
}
|
||||
|
||||
// vectorized iteration through all the pointer/mask/other elements
|
||||
const int valueElemNbits =
|
||||
const int valueElemNBits =
|
||||
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
|
||||
const int numVecs = numElems / vec;
|
||||
|
||||
@@ -117,11 +117,11 @@ struct LoadOpConversion
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
|
||||
const size_t totalWidth = valueElemNBits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
const size_t wordNElems = width / valueElemNBits;
|
||||
const size_t movWidth = width < 16 ? 16 : width;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
@@ -138,18 +138,12 @@ struct LoadOpConversion
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
PTXInstr &init =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
PTXInstr::Operand *zero = ptxBuilder.newConstantOperand(0);
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint,
|
||||
/*init=*/true); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
// Initialize the destination register, otherwise the register will
|
||||
// be undefined if the predicate is false.
|
||||
init(opr, zero);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
@@ -186,7 +180,7 @@ struct LoadOpConversion
|
||||
PTXInstr &mov =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
|
||||
size_t size = width / valueElemNbits;
|
||||
size_t size = width / valueElemNBits;
|
||||
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
|
||||
Value v = undef(vecTy);
|
||||
@@ -201,8 +195,8 @@ struct LoadOpConversion
|
||||
PTXInstr::Operand *opr{};
|
||||
|
||||
if (otherIsSplatConstInt) {
|
||||
for (size_t s = 0; s < 32; s += valueElemNbits)
|
||||
splatVal |= splatVal << valueElemNbits;
|
||||
for (size_t s = 0; s < 32; s += valueElemNBits)
|
||||
splatVal |= splatVal << valueElemNBits;
|
||||
opr = ptxBuilder.newConstantOperand(splatVal);
|
||||
} else
|
||||
opr = ptxBuilder.newOperand(v, readConstraint);
|
||||
@@ -233,10 +227,10 @@ struct LoadOpConversion
|
||||
curr = ret;
|
||||
}
|
||||
curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
|
||||
width / valueElemNbits));
|
||||
width / valueElemNBits));
|
||||
rets.push_back(curr);
|
||||
}
|
||||
int tmp = width / valueElemNbits;
|
||||
int tmp = width / valueElemNBits;
|
||||
for (size_t ii = 0; ii < vec; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
|
||||
@@ -312,18 +306,18 @@ struct StoreOpConversion
|
||||
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNbits = dtsize * 8;
|
||||
const size_t valueElemNBits = dtsize * 8;
|
||||
|
||||
const int numVecs = elemsPerThread / vec;
|
||||
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
|
||||
const size_t totalWidth = valueElemNbits * vec;
|
||||
const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
|
||||
const size_t totalWidth = valueElemNBits * vec;
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
const size_t wordNElems = width / valueElemNBits;
|
||||
assert(wordNElems * nWords * numVecs == elemsPerThread);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
@@ -414,6 +408,7 @@ struct AtomicCASOpConversion
|
||||
Type valueElemTy =
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(0));
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
@@ -424,13 +419,12 @@ struct AtomicCASOpConversion
|
||||
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
|
||||
Value casPtr = ptrElements[0];
|
||||
Value casCmp = cmpElements[0];
|
||||
Value casVal = valElements[0];
|
||||
|
||||
PTXBuilder ptxBuilderAtomicCAS;
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
|
||||
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r", /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
|
||||
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
@@ -441,7 +435,7 @@ struct AtomicCASOpConversion
|
||||
barrier();
|
||||
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
|
||||
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
@@ -498,7 +492,7 @@ struct AtomicRMWOpConversion
|
||||
Type valueElemTy =
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
|
||||
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
auto vec = getVectorSize(ptr);
|
||||
@@ -529,16 +523,16 @@ struct AtomicRMWOpConversion
|
||||
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
|
||||
std::string sTy;
|
||||
PTXBuilder ptxBuilderAtomicRMW;
|
||||
std::string tyId = valueElemNbits * vec == 64
|
||||
std::string tyId = valueElemNBits * vec == 64
|
||||
? "l"
|
||||
: (valueElemNbits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
|
||||
: (valueElemNBits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
|
||||
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
|
||||
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
|
||||
|
||||
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
auto sBits = std::to_string(valueElemNbits);
|
||||
auto sBits = std::to_string(valueElemNBits);
|
||||
switch (atomicRmwAttr) {
|
||||
case RMWOp::AND:
|
||||
sTy = "b" + sBits;
|
||||
@@ -554,9 +548,9 @@ struct AtomicRMWOpConversion
|
||||
break;
|
||||
case RMWOp::FADD:
|
||||
rmwOp = "add";
|
||||
rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
|
||||
rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
|
||||
sTy = "f" + sBits;
|
||||
sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
|
||||
sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
|
||||
break;
|
||||
case RMWOp::MAX:
|
||||
sTy = "s" + sBits;
|
||||
@@ -598,7 +592,14 @@ struct AtomicRMWOpConversion
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
store(old, atomPtr);
|
||||
// Only threads with rmwMask = True store the result
|
||||
PTXBuilder ptxBuilderStore;
|
||||
auto &storeShared =
|
||||
ptxBuilderStore.create<>("st")->shared().o("b" + sBits);
|
||||
auto *ptrOpr = ptxBuilderStore.newAddrOperand(atomPtr, "r");
|
||||
auto *valOpr = ptxBuilderStore.newOperand(old, tyId);
|
||||
storeShared(ptrOpr, valOpr).predicate(rmwMask);
|
||||
ptxBuilderStore.launch(rewriter, loc, void_ty(ctx));
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
|
||||
Reference in New Issue
Block a user