mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Mask out wrapped threads in store ops (#1283)
This commit is contained in:
@@ -798,6 +798,20 @@ def test_store_constant(dtype_str):
|
||||
assert torch.all(output == ref)
|
||||
|
||||
|
||||
def test_load_store_same_ptr():
|
||||
@triton.jit()
|
||||
def kernel(in_out_ptr):
|
||||
pid = tl.program_id(axis=0)
|
||||
x = tl.load(in_out_ptr + pid)
|
||||
out = x * 2
|
||||
tl.store(in_out_ptr + pid, out)
|
||||
|
||||
for _ in range(1000):
|
||||
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
|
||||
kernel[(65536,)](x, num_warps=32)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
|
||||
Reference in New Issue
Block a user