mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: fix AtomicCASOpConversion segfault
This commit is contained in:
@@ -22,6 +22,7 @@ float_dtypes = ['float16', 'float32', 'float64']
|
||||
dtypes = int_dtypes + uint_dtypes + float_dtypes
|
||||
dtypes_with_bfloat16 = dtypes + ['bfloat16']
|
||||
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
|
||||
num_ctas_list = [1]
|
||||
|
||||
|
||||
def hip_skip():
|
||||
@@ -803,22 +804,24 @@ def test_tensor_atomic_rmw_block(device="cuda"):
|
||||
assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
def test_atomic_cas(sem, num_ctas, device):
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
tl.atomic_cas(Lock, 0, 1)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
change_value[(1,)](Lock)
|
||||
Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
|
||||
change_value[(1, )](Lock)
|
||||
|
||||
assert (Lock[0] == 1)
|
||||
|
||||
# 2. only one block enters the critical section
|
||||
@triton.jit
|
||||
def serialized_add(data, Lock):
|
||||
def serialized_add(data, Lock, SEM: tl.constexpr):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
|
||||
pass
|
||||
|
||||
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
@@ -826,11 +829,15 @@ def test_atomic_cas():
|
||||
# release lock
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
|
||||
data = torch.zeros((128, ), device=device, dtype=torch.float32)
|
||||
ref = torch.full((128, ), 64.0)
|
||||
h = serialized_add[(64, )](data, Lock, SEM=sem, num_ctas=num_ctas)
|
||||
sem_str = "acq_rel" if sem is None else sem
|
||||
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
|
||||
if is_hip():
|
||||
return
|
||||
assert f"atom.global.{sem_str}" in h.asm["ptx"]
|
||||
|
||||
|
||||
# ---------------
|
||||
|
||||
Reference in New Issue
Block a user