[BACKEND] Add store cache modifiers (#1826)

Plumb through store cache modifiers.
This commit is contained in:
Thomas
2023-06-23 09:29:10 -07:00
committed by GitHub
parent 2eb7bc4b4c
commit 3d1cd89b54
5 changed files with 56 additions and 4 deletions

View File

@@ -2266,6 +2266,37 @@ def test_vectorization_hints(has_hints):
# test store
# ---------------
@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs"])
def test_store_cache_modifier(cache):
src = torch.empty(128, device='cuda')
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst, src, CACHE: tl.constexpr):
offsets = tl.arange(0, 128)
x = tl.load(src + offsets)
tl.store(dst + offsets, x, cache_modifier=CACHE)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
if cache == '.wb':
assert 'st.global.wb' in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
if cache == '.cg':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' in ptx
assert 'st.global.cs' not in ptx
if cache == '.cs':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' in ptx
# ---------------
# test if
# ---------------