mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fix for LLVM compiler's inability to auto vectorize float16 loads and
stores on AMDGPU.
This commit is contained in:
@@ -132,24 +132,41 @@ struct LoadOpConversion
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
Value zeroVal = bitcast(int_val(valueElemNbits, 0), valueElemTy);
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
for (size_t wordElem = 0; wordElem < wordNElems; ++wordElem) {
|
||||
size_t elemOffset = vecStart + wordIdx * wordNElems + wordElem;
|
||||
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptrElems[elemOffset]);
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
|
||||
},
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
Value otherVal = other ? otherElems[elemOffset] : zeroVal;
|
||||
builder.create<mlir::scf::YieldOp>(loc,ValueRange({otherVal}));
|
||||
size_t elemOffset = vecStart + wordIdx * wordNElems;
|
||||
Value ptr = addrspacecast(ptrElems[elemOffset], ptr_ty(IntegerType::get(getContext(), width)));
|
||||
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
|
||||
},
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
Value zeroVal = bitcast(int_val(valueElemNbits, 0), IntegerType::get(getContext(), width));
|
||||
Value otherVal;
|
||||
if (other) {
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < wordNElems; ++s) {
|
||||
Value falseVal = otherElems[elemOffset + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
otherVal = bitcast(v, IntegerType::get(getContext(), width));
|
||||
}
|
||||
);
|
||||
loadedVals.push_back(loaded->getResult(0));
|
||||
Value falseVal = other ? otherVal : zeroVal;
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
|
||||
}
|
||||
);
|
||||
Value loadVal = bitcast(loaded->getResult(0), LLVM::getFixedVectorType(valueElemTy,
|
||||
wordNElems));
|
||||
for (size_t ii = 0; ii < wordNElems; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % wordNElems);
|
||||
Value loaded = extract_element(valueElemTy, loadVal, vecIdx);
|
||||
loadedVals.push_back(loaded);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -353,26 +370,25 @@ struct StoreOpConversion
|
||||
if (elem.getType().isInteger(1))
|
||||
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
|
||||
elem = bitcast(elem, valueElemTy);
|
||||
#ifdef USE_ROCM
|
||||
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
|
||||
rewriter.create<scf::IfOp>(loc, maskVal,
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
auto storeOp = builder.create<LLVM::StoreOp>(loc, elem, ptrElems[elemOffset]);
|
||||
builder.create<scf::YieldOp>(loc);
|
||||
},
|
||||
nullptr
|
||||
);
|
||||
}
|
||||
}
|
||||
#else
|
||||
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
|
||||
}
|
||||
llWord = bitcast(llWord, valArgTy);
|
||||
#ifdef USE_ROCM
|
||||
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
|
||||
rewriter.create<scf::IfOp>(loc, maskVal,
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
auto storeOp = builder.create<LLVM::StoreOp>(loc, llWord, ptrElems[vecStart + wordIdx * wordNElems]);
|
||||
builder.create<scf::YieldOp>(loc);
|
||||
},
|
||||
nullptr);
|
||||
#else
|
||||
std::string constraint =
|
||||
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
asmArgs.emplace_back(llWord, constraint);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Prepare the PTX inline asm.
|
||||
PTXBuilder ptxBuilder;
|
||||
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
|
||||
|
||||
@@ -35,6 +35,8 @@
|
||||
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
|
||||
#define bitcast(val__, type__) \
|
||||
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
|
||||
#define addrspacecast(val__, type__) \
|
||||
rewriter.create<LLVM::AddrSpaceCastOp>(loc, type__, val__)
|
||||
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
|
||||
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
|
||||
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
|
||||
|
||||
Reference in New Issue
Block a user