ROCM IFU: fix AtomicCASOpConversion segfault

This commit is contained in:
Michael Melesse
2023-12-12 17:40:31 -06:00
parent a42ac260aa
commit 6efc013e46
3 changed files with 112 additions and 46 deletions

View File

@@ -22,6 +22,7 @@ float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
num_ctas_list = [1]
def hip_skip():
@@ -803,22 +804,24 @@ def test_tensor_atomic_rmw_block(device="cuda"):
assert torch.min(x).item() == 0.0
def test_atomic_cas():
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_atomic_cas(sem, num_ctas, device):
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
tl.atomic_cas(Lock, 0, 1)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
change_value[(1,)](Lock)
Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
change_value[(1, )](Lock)
assert (Lock[0] == 1)
# 2. only one block enters the critical section
@triton.jit
def serialized_add(data, Lock):
def serialized_add(data, Lock, SEM: tl.constexpr):
ptrs = data + tl.arange(0, 128)
while tl.atomic_cas(Lock, 0, 1) == 1:
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
pass
tl.store(ptrs, tl.load(ptrs) + 1.0)
@@ -826,11 +829,15 @@ def test_atomic_cas():
# release lock
tl.atomic_xchg(Lock, 0)
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
data = torch.zeros((128, ), device=device, dtype=torch.float32)
ref = torch.full((128, ), 64.0)
h = serialized_add[(64, )](data, Lock, SEM=sem, num_ctas=num_ctas)
sem_str = "acq_rel" if sem is None else sem
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
if is_hip():
return
assert f"atom.global.{sem_str}" in h.asm["ptx"]
# ---------------