mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Add store cache modifiers (#1826)
Plumb through store cache modifiers.
This commit is contained in:
@@ -2266,6 +2266,37 @@ def test_vectorization_hints(has_hints):
|
||||
# test store
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs"])
|
||||
def test_store_cache_modifier(cache):
|
||||
src = torch.empty(128, device='cuda')
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst, src, CACHE: tl.constexpr):
|
||||
offsets = tl.arange(0, 128)
|
||||
x = tl.load(src + offsets)
|
||||
tl.store(dst + offsets, x, cache_modifier=CACHE)
|
||||
|
||||
pgm = _kernel[(1,)](dst, src, CACHE=cache)
|
||||
ptx = pgm.asm['ptx']
|
||||
if cache == '':
|
||||
assert 'st.global.wb' not in ptx
|
||||
assert 'st.global.cg' not in ptx
|
||||
assert 'st.global.cs' not in ptx
|
||||
if cache == '.wb':
|
||||
assert 'st.global.wb' in ptx
|
||||
assert 'st.global.cg' not in ptx
|
||||
assert 'st.global.cs' not in ptx
|
||||
if cache == '.cg':
|
||||
assert 'st.global.wb' not in ptx
|
||||
assert 'st.global.cg' in ptx
|
||||
assert 'st.global.cs' not in ptx
|
||||
if cache == '.cs':
|
||||
assert 'st.global.wb' not in ptx
|
||||
assert 'st.global.cg' not in ptx
|
||||
assert 'st.global.cs' in ptx
|
||||
|
||||
# ---------------
|
||||
# test if
|
||||
# ---------------
|
||||
|
||||
Reference in New Issue
Block a user