Fix for LLVM compiler's inability to auto-vectorize float16 loads and stores on AMDGPU.
This commit is contained in:
Rohit Santhanam
2023-03-03 08:01:19 +00:00
parent f9bd9908a1
commit 20ef9a0908
2 changed files with 44 additions and 26 deletions

View File

@@ -132,24 +132,41 @@ struct LoadOpConversion
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
#ifdef USE_ROCM
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
Value zeroVal = bitcast(int_val(valueElemNbits, 0), valueElemTy);
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
for (size_t wordElem = 0; wordElem < wordNElems; ++wordElem) {
size_t elemOffset = vecStart + wordIdx * wordNElems + wordElem;
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
[&](OpBuilder &builder, Location loc){
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptrElems[elemOffset]);
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
},
[&](OpBuilder &builder, Location loc){
Value otherVal = other ? otherElems[elemOffset] : zeroVal;
builder.create<mlir::scf::YieldOp>(loc,ValueRange({otherVal}));
size_t elemOffset = vecStart + wordIdx * wordNElems;
Value ptr = addrspacecast(ptrElems[elemOffset], ptr_ty(IntegerType::get(getContext(), width)));
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
[&](OpBuilder &builder, Location loc){
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
},
[&](OpBuilder &builder, Location loc){
Value zeroVal = bitcast(int_val(valueElemNbits, 0), IntegerType::get(getContext(), width));
Value otherVal;
if (other) {
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
Value v = undef(vecTy);
for (size_t s = 0; s < wordNElems; ++s) {
Value falseVal = otherElems[elemOffset + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
otherVal = bitcast(v, IntegerType::get(getContext(), width));
}
);
loadedVals.push_back(loaded->getResult(0));
Value falseVal = other ? otherVal : zeroVal;
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
}
);
Value loadVal = bitcast(loaded->getResult(0), LLVM::getFixedVectorType(valueElemTy,
wordNElems));
for (size_t ii = 0; ii < wordNElems; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % wordNElems);
Value loaded = extract_element(valueElemTy, loadVal, vecIdx);
loadedVals.push_back(loaded);
}
}
#else
@@ -353,26 +370,25 @@ struct StoreOpConversion
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bitcast(elem, valueElemTy);
#ifdef USE_ROCM
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
rewriter.create<scf::IfOp>(loc, maskVal,
[&](OpBuilder &builder, Location loc){
auto storeOp = builder.create<LLVM::StoreOp>(loc, elem, ptrElems[elemOffset]);
builder.create<scf::YieldOp>(loc);
},
nullptr
);
}
}
#else
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
#ifdef USE_ROCM
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
rewriter.create<scf::IfOp>(loc, maskVal,
[&](OpBuilder &builder, Location loc){
auto storeOp = builder.create<LLVM::StoreOp>(loc, llWord, ptrElems[vecStart + wordIdx * wordNElems]);
builder.create<scf::YieldOp>(loc);
},
nullptr);
#else
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgs.emplace_back(llWord, constraint);
#endif
}
#ifndef USE_ROCM
// Prepare the PTX inline asm.
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);

View File

@@ -35,6 +35,8 @@
// Shorthand macros for building LLVM-dialect ops and types.
// Every macro that expands to `rewriter.create<...>(loc, ...)` requires
// variables named `rewriter` and `loc` to be in scope at the expansion site.

// Creates an LLVM::XOrOp at `loc`, forwarding all arguments.
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
// Creates an LLVM::BitcastOp. Note the macro takes (value, type) but forwards
// them to the builder as (type, value).
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
// Creates an LLVM::AddrSpaceCastOp. Same (value, type) argument order as
// `bitcast`, forwarded to the builder as (type, value).
#define addrspacecast(val__, type__) \
rewriter.create<LLVM::AddrSpaceCastOp>(loc, type__, val__)
// Creates an LLVM::GEPOp at `loc`, forwarding all arguments.
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
// Builds an LLVM pointer type; does not use `rewriter` or `loc`.
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
// Creates an LLVM::InsertValueOp at `loc`, forwarding all arguments.
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)