mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fix masked load (#262)
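* Fix the issue with masked load: when the mask is false and no `other` value is given, build the zero result directly at the load word width (`int_val(width, 0)`) instead of bitcasting an element-width zero to the word-width integer type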
Cherry-picked from a0b60eb187
* Remove tests in test_gemm that use too much LDS
---------
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
@@ -130,32 +130,36 @@ struct LoadOpConversion
|
||||
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
size_t elemOffset = vecStart + wordIdx * wordNElems;
|
||||
Value ptr = addrspacecast(ptrElems[elemOffset], ptr_ty(IntegerType::get(getContext(), width)));
|
||||
auto loaded = rewriter.create<scf::IfOp>(loc, pred,
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
|
||||
},
|
||||
[&](OpBuilder &builder, Location loc){
|
||||
Value zeroVal = bitcast(int_val(valueElemNBits, 0), IntegerType::get(getContext(), width));
|
||||
Value otherVal;
|
||||
if (other) {
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < wordNElems; ++s) {
|
||||
Value falseVal = otherElems[elemOffset + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
otherVal = bitcast(v, IntegerType::get(getContext(), width));
|
||||
}
|
||||
Value falseVal = other ? otherVal : zeroVal;
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
|
||||
}
|
||||
);
|
||||
Value loadVal = bitcast(loaded->getResult(0), LLVM::getFixedVectorType(valueElemTy,
|
||||
wordNElems));
|
||||
Value ptr =
|
||||
addrspacecast(ptrElems[elemOffset],
|
||||
ptr_ty(IntegerType::get(getContext(), width)));
|
||||
auto loaded = rewriter.create<scf::IfOp>(
|
||||
loc, pred,
|
||||
[&](OpBuilder &builder, Location loc) {
|
||||
auto loadVal = builder.create<LLVM::LoadOp>(loc, ptr);
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({loadVal}));
|
||||
},
|
||||
[&](OpBuilder &builder, Location loc) {
|
||||
Value zeroVal = int_val(width, 0);
|
||||
Value otherVal;
|
||||
if (other) {
|
||||
auto vecTy = LLVM::getFixedVectorType(valueElemTy, wordNElems);
|
||||
Value v = undef(vecTy);
|
||||
for (size_t s = 0; s < wordNElems; ++s) {
|
||||
Value falseVal = otherElems[elemOffset + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(),
|
||||
s);
|
||||
v = insert_element(vecTy, v, falseVal, sVal);
|
||||
}
|
||||
otherVal = bitcast(v, IntegerType::get(getContext(), width));
|
||||
}
|
||||
Value falseVal = other ? otherVal : zeroVal;
|
||||
builder.create<mlir::scf::YieldOp>(loc, ValueRange({falseVal}));
|
||||
});
|
||||
Value loadVal =
|
||||
bitcast(loaded->getResult(0),
|
||||
LLVM::getFixedVectorType(valueElemTy, wordNElems));
|
||||
for (size_t ii = 0; ii < wordNElems; ++ii) {
|
||||
Value vecIdx = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % wordNElems);
|
||||
|
||||
@@ -31,11 +31,11 @@ def matmul_no_scf_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS', [
|
||||
[128, 256, 32, 4],
|
||||
[128, 64, 32, 4],
|
||||
[256, 128, 16, 4],
|
||||
[128, 16, 32, 4],
|
||||
[128, 32, 32, 4],
|
||||
[32, 128, 64, 4],
|
||||
[128, 128, 64, 4],
|
||||
[128, 32, 64, 4],
|
||||
[64, 128, 128, 4],
|
||||
[64, 128, 128, 2],
|
||||
])
|
||||
@@ -106,16 +106,16 @@ def get_variant_golden(a, b):
|
||||
[128, 64, 128, 4, 128, 64, 128],
|
||||
# K-Forloop
|
||||
[64, 32, 128, 4, 64, 32, 64],
|
||||
[128, 16, 128, 4, 128, 16, 32],
|
||||
[32, 16, 128, 4, 32, 16, 32],
|
||||
[128, 32, 128, 4, 128, 32, 32],
|
||||
[32, 32, 128, 4, 32, 32, 32],
|
||||
[32, 64, 128, 4, 32, 64, 32],
|
||||
[32, 128, 256, 4, 32, 128, 64],
|
||||
[64, 128, 64, 4, 64, 128, 32],
|
||||
[64, 64, 128, 4, 64, 64, 32],
|
||||
[128, 128, 64, 4, 128, 128, 32],
|
||||
[128, 128, 128, 4, 128, 128, 32],
|
||||
[128, 128, 256, 4, 128, 128, 64],
|
||||
[128, 256, 128, 4, 128, 256, 32],
|
||||
[128, 64, 64, 4, 128, 64, 32],
|
||||
[128, 64, 128, 4, 128, 64, 32],
|
||||
[64, 64, 256, 4, 64, 64, 64],
|
||||
[128, 64, 128, 4, 128, 64, 32],
|
||||
[256, 128, 64, 4, 256, 128, 16],
|
||||
[128, 64, 128, 4, 128, 64, 32],
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user