Fix masked load (#262)

* Fix the issue with masked load

Cherry-picked from a0b60eb187

* Remove tests in test_gemm that use too much LDS

---------

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
Lixun Zhang
2023-07-26 10:45:21 -05:00
committed by GitHub
parent 1cccf14f62
commit 2fbffe2784
2 changed files with 39 additions and 35 deletions

View File

@@ -130,32 +130,36 @@ struct LoadOpConversion
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
size_t elemOffset = vecStart + wordIdx * wordNElems;
Value ptr = addrspacecast(ptrElems[elemOffset], ptr_ty(IntegerType::get(getContext(), width)));
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
[&](OpBuilder &builder, Location loc){
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
},
[&](OpBuilder &builder, Location loc){
Value zeroVal = bitcast(int_val(valueElemNBits, 0), IntegerType::get(getContext(), width));
Value otherVal;
if (other) {
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
Value v = undef(vecTy);
for (size_t s = 0; s < wordNElems; ++s) {
Value falseVal = otherElems[elemOffset + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
otherVal = bitcast(v, IntegerType::get(getContext(), width));
}
Value falseVal = other ? otherVal : zeroVal;
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
}
);
Value loadVal = bitcast(loaded->getResult(0), LLVM::getFixedVectorType(valueElemTy,
wordNElems));
Value ptr =
addrspacecast(ptrElems[elemOffset],
ptr_ty(IntegerType::get(getContext(), width)));
auto loaded = rewriter.create<scf::IfOp>(
loc, pred,
[&](OpBuilder &builder, Location loc) {
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
},
[&](OpBuilder &builder, Location loc) {
Value zeroVal = int_val(width, 0);
Value otherVal;
if (other) {
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
Value v = undef(vecTy);
for (size_t s = 0; s < wordNElems; ++s) {
Value falseVal = otherElems[elemOffset + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(),
s);
v = insert_element(vecTy, v, falseVal, sVal);
}
otherVal = bitcast(v, IntegerType::get(getContext(), width));
}
Value falseVal = other ? otherVal : zeroVal;
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
});
Value loadVal =
bitcast(loaded->getResult(0),
LLVM::getFixedVectorType(valueElemTy, wordNElems));
for (size_t ii = 0; ii < wordNElems; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % wordNElems);