mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Add acquire/release semantics for atomics (#1739)
This commit is contained in:
@@ -937,15 +937,16 @@ def test_noinline(mode):
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
|
||||
@pytest.mark.parametrize("op, dtype_x_str, mode, sem", itertools.chain.from_iterable([
|
||||
[
|
||||
('add', 'float16', mode),
|
||||
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
|
||||
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
|
||||
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
|
||||
('add', 'float16', mode, sem),
|
||||
('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem),
|
||||
('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem),
|
||||
('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem),
|
||||
]
|
||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
|
||||
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
|
||||
for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
|
||||
def test_atomic_rmw(op, dtype_x_str, mode, sem, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
if dtype_x_str == 'float16':
|
||||
@@ -959,7 +960,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
x = tl.load(X + pid)
|
||||
old = GENERATE_TEST_HERE
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
|
||||
sem_arg = sem if sem is None else f'"{sem}"'
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
|
||||
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
|
||||
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
|
||||
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
|
||||
@@ -981,7 +983,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
x_tri = to_triton(x, device=device)
|
||||
|
||||
z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
|
||||
kernel[(n_programs, )](x_tri, z_tri)
|
||||
h = kernel[(n_programs, )](x_tri, z_tri)
|
||||
# torch result
|
||||
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
|
||||
# compare
|
||||
@@ -990,6 +992,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
|
||||
assert z_ref.item() == to_numpy(z_tri).item()
|
||||
else:
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
|
||||
sem_str = "acq_rel" if sem is None else sem
|
||||
assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
|
||||
|
||||
|
||||
def test_atomic_rmw_predicate(device="cuda"):
|
||||
@@ -1047,7 +1051,8 @@ def test_tensor_atomic_rmw_block(device="cuda"):
|
||||
assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
|
||||
def test_atomic_cas(sem):
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
@@ -1060,9 +1065,9 @@ def test_atomic_cas():
|
||||
|
||||
# 2. only one block enters the critical section
|
||||
@triton.jit
|
||||
def serialized_add(data, Lock):
|
||||
def serialized_add(data, Lock, SEM: tl.constexpr):
|
||||
ptrs = data + tl.arange(0, 128)
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
|
||||
pass
|
||||
|
||||
tl.store(ptrs, tl.load(ptrs) + 1.0)
|
||||
@@ -1073,8 +1078,10 @@ def test_atomic_cas():
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
|
||||
ref = torch.full((128,), 64.0)
|
||||
serialized_add[(64,)](data, Lock)
|
||||
h = serialized_add[(64,)](data, Lock, SEM=sem)
|
||||
sem_str = "acq_rel" if sem is None else sem
|
||||
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
|
||||
assert f"atom.global.{sem_str}" in h.asm["ptx"]
|
||||
|
||||
|
||||
# ---------------
|
||||
|
||||
Reference in New Issue
Block a user