mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] interpreter rewrite (#2321)
This is a new interpreter mode that shares semantic analysis with the JIT'ed code path and that the Triton core team is committed to maintaining.
This commit is contained in:
@@ -1,69 +0,0 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.interpreter.interpreter import program_ids_from_grid
|
||||
|
||||
|
||||
def test_addition():
    """Element-wise vector addition through the interpreted Triton kernel.

    Runs a standard masked add kernel under ``interpret=True`` and checks the
    result against the eager torch sum.
    """

    @triton.jit(interpret=True)
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance handles one contiguous BLOCK_SIZE chunk.
        pid = tl.program_id(axis=0)
        chunk_start = pid * BLOCK_SIZE
        offs = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = offs < n_elements
        lhs = tl.load(x_ptr + offs, mask=in_bounds)
        rhs = tl.load(y_ptr + offs, mask=in_bounds)
        tl.store(output_ptr + offs, lhs + rhs, mask=in_bounds)

    n = 128
    lhs = torch.rand((n,), device="cuda")
    rhs = torch.rand((n,), device="cuda")
    want = lhs + rhs
    got = torch.empty((n,), device="cuda")

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the n-element vectors.
        return (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    add_kernel[grid](lhs, rhs, got, n, BLOCK_SIZE=32)

    assert torch.allclose(want, got, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
def test_program_ids_from_grid():
    """``program_ids_from_grid`` yields every grid cell, in a shuffled order."""
    random.seed(123)
    grid = (3, 4)

    # Every (pid0, pid1) combination must appear — 3 * 4 unique tuples.
    expected_combinations = 3 * 4
    unique_combinations = set(program_ids_from_grid(grid))
    assert len(unique_combinations) == expected_combinations

    # Traversal order is randomized, so two runs should not match.
    first_run = list(program_ids_from_grid(grid))
    second_run = list(program_ids_from_grid(grid))
    assert first_run != second_run
|
||||
|
||||
|
||||
def test_atomic():
    """Exercise atomic_add / atomic_xchg / atomic_cas in interpreter mode.

    Each lane starts at 0 and, after the sequence of atomics below, must end
    at 2.
    """

    @triton.jit(interpret=True)
    def atomic(
        x_ptr,
    ):
        pid = tl.program_id(axis=0)
        tl.atomic_add(x_ptr + pid, 1)          # 0 -> 1
        prev = tl.atomic_xchg(x_ptr + pid, 3)  # slot becomes 3, prev == 1
        prev += 1                              # 2
        tl.atomic_cas(x_ptr + pid, 3, prev)    # 3 matches -> stores 2
        tl.atomic_cas(x_ptr + pid, 40, 9)      # 40 doesn't match -> no-op

    nb_dim = 16
    buf = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")

    # One program per element.
    atomic[(nb_dim, )](buf)
    assert torch.allclose(buf, torch.full_like(buf, 2))
|
||||
@@ -2421,7 +2421,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
if epilogue == 'chain-dot':
|
||||
z_ref = np.matmul(z_ref, w)
|
||||
# compare
|
||||
# print(z_ref[:,0], z_tri[:,0])
|
||||
if in_dtype == 'float32':
|
||||
# XXX: Somehow there's a larger difference when we use float32
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
|
||||
|
||||
@@ -5,10 +5,10 @@ import triton
|
||||
import triton.ops
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
|
||||
(4, 48, 1024, 32),
|
||||
(4, 48, 1024, 64),
|
||||
(4, 48, 1024, 128)])
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
|
||||
(2, 4, 512, 32),
|
||||
(2, 4, 512, 64),
|
||||
(2, 4, 512, 128)])
|
||||
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize('causal', [True, False])
|
||||
@pytest.mark.parametrize('seq_par', [True, False])
|
||||
@@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
pytest.skip('Segmentation fault')
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
|
||||
if not interpreter and capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
|
||||
Reference in New Issue
Block a user