Fix masked load (#262)

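* Fix the issue with masked load: when the mask is false and no `other` value is given, build the zero fallback directly as a `width`-bit integer constant instead of bitcasting a `valueElemNBits`-bit zero to a `width`-bit integer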

Cherry-picked from a0b60eb187

* Remove tests in test_gemm that use too much LDS (local data share, the GPU's on-chip shared memory)

---------

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
Lixun Zhang
2023-07-26 10:45:21 -05:00
committed by GitHub
parent 1cccf14f62
commit 2fbffe2784
2 changed files with 39 additions and 35 deletions

View File

@@ -130,32 +130,36 @@ struct LoadOpConversion
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
size_t elemOffset = vecStart + wordIdx * wordNElems;
Value ptr = addrspacecast(ptrElems[elemOffset], ptr_ty(IntegerType::get(getContext(), width)));
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
[&](OpBuilder &builder, Location loc){
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
},
[&](OpBuilder &builder, Location loc){
Value zeroVal = bitcast(int_val(valueElemNBits, 0), IntegerType::get(getContext(), width));
Value otherVal;
if (other) {
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
Value v = undef(vecTy);
for (size_t s = 0; s < wordNElems; ++s) {
Value falseVal = otherElems[elemOffset + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
otherVal = bitcast(v, IntegerType::get(getContext(), width));
}
Value falseVal = other ? otherVal : zeroVal;
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
}
);
Value loadVal = bitcast(loaded->getResult(0), LLVM::getFixedVectorType(valueElemTy,
wordNElems));
Value ptr =
addrspacecast(ptrElems[elemOffset],
ptr_ty(IntegerType::get(getContext(), width)));
auto loaded = rewriter.create<scf::IfOp>(
loc, pred,
[&](OpBuilder &builder, Location loc) {
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
},
[&](OpBuilder &builder, Location loc) {
Value zeroVal = int_val(width, 0);
Value otherVal;
if (other) {
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
Value v = undef(vecTy);
for (size_t s = 0; s < wordNElems; ++s) {
Value falseVal = otherElems[elemOffset + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(),
s);
v = insert_element(vecTy, v, falseVal, sVal);
}
otherVal = bitcast(v, IntegerType::get(getContext(), width));
}
Value falseVal = other ? otherVal : zeroVal;
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
});
Value loadVal =
bitcast(loaded->getResult(0),
LLVM::getFixedVectorType(valueElemTy, wordNElems));
for (size_t ii = 0; ii < wordNElems; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % wordNElems);

View File

@@ -31,11 +31,11 @@ def matmul_no_scf_kernel(
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
[128, 256, 32, 4],
[128, 64, 32, 4],
[256, 128, 16, 4],
[128, 16, 32, 4],
[128, 32, 32, 4],
[32, 128, 64, 4],
[128, 128, 64, 4],
[128, 32, 64, 4],
[64, 128, 128, 4],
[64, 128, 128, 2],
])
@@ -106,16 +106,16 @@ def get_variant_golden(a, b):
[128, 64, 128, 4, 128, 64, 128],
# K-Forloop
[64, 32, 128, 4, 64, 32, 64],
[128, 16, 128, 4, 128, 16, 32],
[32, 16, 128, 4, 32, 16, 32],
[128, 32, 128, 4, 128, 32, 32],
[32, 32, 128, 4, 32, 32, 32],
[32, 64, 128, 4, 32, 64, 32],
[32, 128, 256, 4, 32, 128, 64],
[64, 128, 64, 4, 64, 128, 32],
[64, 64, 128, 4, 64, 64, 32],
[128, 128, 64, 4, 128, 128, 32],
[128, 128, 128, 4, 128, 128, 32],
[128, 128, 256, 4, 128, 128, 64],
[128, 256, 128, 4, 128, 256, 32],
[128, 64, 64, 4, 128, 64, 32],
[128, 64, 128, 4, 128, 64, 32],
[64, 64, 256, 4, 64, 64, 64],
[128, 64, 128, 4, 128, 64, 32],
[256, 128, 64, 4, 256, 128, 16],
[128, 64, 128, 4, 128, 64, 32],
])