mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fix masked load (#262)
* Fix the issue with masked load
Cherry-picked from a0b60eb187
* Remove tests in test_gemm that use too much LDS
---------
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
@@ -130,32 +130,36 @@ struct LoadOpConversion
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
size_t elemOffset = vecStart + wordIdx * wordNElems;
|
||||
Value ptr = addrspacecast(ptrElems[elemOffset], ptr_ty(IntegerType::get(getContext(), width)));
|
||||
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
|
||||
},
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
Value zeroVal = bitcast(int_val(valueElemNBits, 0), IntegerType::get(getContext(), width));
|
||||
Value otherVal;
|
||||
if (other) {
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < wordNElems; ++s) {
|
||||
Value falseVal = otherElems[elemOffset + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
otherVal = bitcast(v, IntegerType::get(getContext(), width));
|
||||
}
|
||||
Value falseVal = other ? otherVal : zeroVal;
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
|
||||
}
|
||||
);
|
||||
Value loadVal = bitcast(loaded->getResult(0), LLVM::getFixedVectorType(valueElemTy,
|
||||
wordNElems));
|
||||
Value ptr =
|
||||
addrspacecast(ptrElems[elemOffset],
|
||||
ptr_ty(IntegerType::get(getContext(), width)));
|
||||
auto loaded = rewriter.create<scf::IfOp>(
|
||||
loc, pred,
|
||||
[&](OpBuilder &builder, Location loc) {
|
||||
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
|
||||
},
|
||||
[&](OpBuilder &builder, Location loc) {
|
||||
Value zeroVal = int_val(width, 0);
|
||||
Value otherVal;
|
||||
if (other) {
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < wordNElems; ++s) {
|
||||
Value falseVal = otherElems[elemOffset + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(),
|
||||
s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
otherVal = bitcast(v, IntegerType::get(getContext(), width));
|
||||
}
|
||||
Value falseVal = other ? otherVal : zeroVal;
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
|
||||
});
|
||||
Value loadVal =
|
||||
bitcast(loaded->getResult(0),
|
||||
LLVM::getFixedVectorType(valueElemTy, wordNElems));
|
||||
for (size_t ii = 0; ii < wordNElems; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % wordNElems);
|
||||
|
||||
Reference in New Issue
Block a user