[FRONTEND][BACKEND] Add acquire/release semantics for atomics (#1739)

This commit is contained in:
Philippe Tillet
2023-06-05 19:09:13 -07:00
committed by GitHub
parent 9c8d7c18b3
commit c52a91231a
9 changed files with 141 additions and 73 deletions

View File

@@ -937,15 +937,16 @@ def test_noinline(mode):
# ---------------
# test atomics
# ---------------
@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([
@pytest.mark.parametrize("op, dtype_x_str, mode, sem", itertools.chain.from_iterable([
[
('add', 'float16', mode),
('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode),
('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode),
('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode),
('add', 'float16', mode, sem),
('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem),
('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem),
('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem),
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
def test_atomic_rmw(op, dtype_x_str, mode, sem, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
if dtype_x_str == 'float16':
@@ -959,7 +960,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
x = tl.load(X + pid)
old = GENERATE_TEST_HERE
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'})
sem_arg = sem if sem is None else f'"{sem}"'
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
@@ -981,7 +983,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
x_tri = to_triton(x, device=device)
z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
kernel[(n_programs, )](x_tri, z_tri)
h = kernel[(n_programs, )](x_tri, z_tri)
# torch result
z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
# compare
@@ -990,6 +992,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
assert z_ref.item() == to_numpy(z_tri).item()
else:
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
sem_str = "acq_rel" if sem is None else sem
assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
def test_atomic_rmw_predicate(device="cuda"):
@@ -1047,7 +1051,8 @@ def test_tensor_atomic_rmw_block(device="cuda"):
assert torch.min(x).item() == 0.0
def test_atomic_cas():
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
def test_atomic_cas(sem):
# 1. make sure that atomic_cas changes the original value (Lock)
@triton.jit
def change_value(Lock):
@@ -1060,9 +1065,9 @@ def test_atomic_cas():
# 2. only one block enters the critical section
@triton.jit
def serialized_add(data, Lock):
def serialized_add(data, Lock, SEM: tl.constexpr):
ptrs = data + tl.arange(0, 128)
while tl.atomic_cas(Lock, 0, 1) == 1:
while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
pass
tl.store(ptrs, tl.load(ptrs) + 1.0)
@@ -1073,8 +1078,10 @@ def test_atomic_cas():
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
data = torch.zeros((128,), device='cuda', dtype=torch.float32)
ref = torch.full((128,), 64.0)
serialized_add[(64,)](data, Lock)
h = serialized_add[(64,)](data, Lock, SEM=sem)
sem_str = "acq_rel" if sem is None else sem
np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
assert f"atom.global.{sem_str}" in h.asm["ptx"]
# ---------------