[BACKEND] Mask out wrapped threads in store ops (#1283)

This commit is contained in:
Keren Zhou
2023-03-06 14:50:20 -08:00
committed by GitHub
parent 5e92a66267
commit 4731f300d3
3 changed files with 97 additions and 29 deletions

View File

@@ -258,7 +258,6 @@ struct StoreOpConversion
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.getPtr();
Value mask = op.getMask();
Value value = op.getValue();
Value llPtr = adaptor.getPtr();
@@ -273,7 +272,7 @@ struct StoreOpConversion
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
unsigned elemsPerThread = getElemsPerThread(ptr.getType());
auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
ptr.getType());
@@ -284,6 +283,7 @@ struct StoreOpConversion
// Determine the vectorization size
SmallVector<Value> maskElems;
if (llMask) {
Value mask = op.getMask();
maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
mask.getType());
assert(valueElems.size() == maskElems.size());
@@ -292,12 +292,20 @@ struct StoreOpConversion
vec = std::min(vec, maskAlign);
}
// numElements = 1 for scalar
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = numElems / vec;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
const int numVecs = elemsPerThread / vec;
for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
@@ -306,7 +314,7 @@ struct StoreOpConversion
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
assert(wordNElems * nWords * numVecs == elemsPerThread);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
@@ -339,7 +347,7 @@ struct StoreOpConversion
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
@@ -392,10 +400,10 @@ struct AtomicCASOpConversion
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
PTXBuilder ptxBuilderMemfence;
@@ -462,7 +470,6 @@ struct AtomicRMWOpConversion
Value val = op.getVal();
Value ptr = op.getPtr();
Value _mask = op.getMask();
Value llPtr = adaptor.getPtr();
Value llVal = adaptor.getVal();
@@ -472,29 +479,31 @@ struct AtomicRMWOpConversion
loc, llVal, rewriter, val.getType());
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, ptr.getType());
auto maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, _mask.getType());
SmallVector<Value> maskElements;
if (llMask)
maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, op.getMask().getType());
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
Value mask = int_val(1, 1);
auto tid = tid_val();
int numElems = 1;
// tensor
if (valueTy) {
if (tensorTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
@@ -507,8 +516,7 @@ struct AtomicRMWOpConversion
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
std::string tyId = valueElemNbits * vec == 64
@@ -561,7 +569,7 @@ struct AtomicRMWOpConversion
return failure();
}
atom.o(rmwOp).o(sTy);
if (valueTy) {
if (tensorTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
@@ -587,8 +595,8 @@ struct AtomicRMWOpConversion
rewriter.replaceOp(op, {ret});
}
}
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
if (tensorTy) {
Type structTy = getTypeConverter()->convertType(tensorTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});