[FRONTEND] interpreter rewrite (#2321)

This is a new interpreter mode that shares semantic analysis with the
JIT'ed codepath and that the Triton core team is committed to maintaining.
This commit is contained in:
Philippe Tillet
2023-09-17 14:58:50 -07:00
committed by GitHub
parent 2b066000aa
commit e686b4d6d4
17 changed files with 599 additions and 1033 deletions

View File

@@ -1,69 +0,0 @@
import random
import torch
import triton
import triton.language as tl
from triton.interpreter.interpreter import program_ids_from_grid
def test_addition():
    """Check an interpreted element-wise add kernel against ``a + b``.

    Two random 128-element CUDA vectors are summed by a kernel launched
    with BLOCK_SIZE=32, and the result is compared to the torch sum with
    an absolute tolerance of 1e-2.
    """

    @triton.jit(interpret=True)
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance handles one BLOCK_SIZE-wide slice.
        start = tl.program_id(axis=0) * BLOCK_SIZE
        offs = start + tl.arange(0, BLOCK_SIZE)
        # Mask off lanes that fall past the end of the input.
        in_bounds = offs < n_elements
        lhs = tl.load(x_ptr + offs, mask=in_bounds)
        rhs = tl.load(y_ptr + offs, mask=in_bounds)
        tl.store(output_ptr + offs, lhs + rhs, mask=in_bounds)

    size = 128
    a = torch.rand((size,), device="cuda")
    b = torch.rand((size,), device="cuda")
    expected = a + b
    output = torch.empty((size,), device="cuda")

    def launch_grid(meta):
        # One program per BLOCK_SIZE chunk, rounded up.
        return (triton.cdiv(size, meta["BLOCK_SIZE"]),)

    add_kernel[launch_grid](a, b, output, size, BLOCK_SIZE=32)

    assert torch.allclose(expected, output, atol=1e-2, rtol=0)
def test_program_ids_from_grid():
    """Check what ``program_ids_from_grid`` yields for a 3x4 grid.

    Asserts that one enumeration contains 12 distinct items and that two
    consecutive enumerations are not identical.
    """
    # NOTE(review): seeding `random` suggests the helper draws its order
    # from that module -- confirm against its implementation.
    random.seed(123)
    grid = (3, 4)
    rows, cols = grid

    distinct_ids = set(program_ids_from_grid(grid))
    assert len(distinct_ids) == rows * cols

    # Two back-to-back enumerations must come out in different orders.
    run_a = list(program_ids_from_grid(grid))
    run_b = list(program_ids_from_grid(grid))
    assert run_a != run_b
def test_atomic():
    """Run atomic add / xchg / cas in an interpreted kernel.

    Starting from a zeroed int32 buffer, each of 16 program instances
    applies a fixed sequence of atomics to its own slot; every slot is
    expected to end up holding 2.
    """

    @triton.jit(interpret=True)
    def atomic(
        x_ptr,
    ):
        slot = x_ptr + tl.program_id(axis=0)
        tl.atomic_add(slot, 1)
        prev = tl.atomic_xchg(slot, 3)
        prev += 1  # 2
        tl.atomic_cas(slot, 3, prev)  # match
        tl.atomic_cas(slot, 40, 9)  # no match

    n_programs = 16
    a = torch.zeros((n_programs,), dtype=torch.int32, device="cuda")
    atomic[(n_programs,)](a)

    expected = torch.full_like(a, 2)
    assert torch.allclose(a, expected)

View File

@@ -2421,7 +2421,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
if in_dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)

View File

@@ -5,10 +5,10 @@ import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
(4, 48, 1024, 32),
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
(2, 4, 512, 32),
(2, 4, 512, 64),
(2, 4, 512, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
@@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
pytest.skip('Segmentation fault')
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
if not interpreter and capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()