match masked load

This commit is contained in:
Michael Melesse
2023-04-11 15:20:08 -05:00
parent 9a8e334859
commit f50116208f
3 changed files with 8 additions and 5 deletions

View File

@@ -124,6 +124,7 @@ struct LoadOpConversion
const size_t wordNElems = width / valueElemNBits;
const size_t movWidth = width < 16 ? 16 : width;
assert(wordNElems * nWords * numVecs == numElems);
#ifdef USE_ROCM
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
@@ -376,6 +377,7 @@ struct StoreOpConversion
if (elem.getType().isInteger(1))
elem = sext(i8_ty, elem);
elem = bitcast(elem, valueElemTy);
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);

View File

@@ -1470,15 +1470,15 @@ def test_masked_load_shared_memory(dtype, device='cuda'):
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N)
# Without a dot product the memory doesn't get promoted to shared.
o = tl.dot(x, w)
o = tl.dot(x, w, out_dtype=tl.float32)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],

View File

@@ -14,12 +14,13 @@ chmod -R 777 $LOG_DIR
sh scripts/amd/clean.sh
# UNIT_TEST="python/test/unit/language/test_core_amd.py"
UNIT_TEST="python/test/unit/language/test_core_amd.py::test_shift_op[int8-int8-<<]"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_shift_op[int8-int8-<<]"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_shift_op[int32-int32->>]"
# UNIT_TEST="python/test/unit/language/test_core.py::test_empty_kernel[float32]"
# UNIT_TEST="python/test/unit/language/test_core.py::test_bin_op"
# UNIT_TEST="python/test/unit/language/test_core.py::test_bin_op[float32-float32-+]"
# UNIT_TEST="python/test/unit/language/test_core.py::test_bin_op[int8-float16-%]"
UNIT_TEST="python/test/unit/language/test_core.py::test_masked_load_shared_memory[dtype0]"
# UNIT_TEST="python/test/unit/language/test_elementwise.py"
# check for backtrace