[TESTING] Added more tests for annotations and autotuner (#1533)

Essentially identical to #538, but it fails formatting tests and I don't
want to ping the author on a weekend.
This commit is contained in:
Philippe Tillet
2023-04-15 19:44:08 -07:00
committed by GitHub
parent df6c2babbd
commit 608ec061c1
2 changed files with 43 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
import torch
import triton
import triton.language as tl
def test_annotations():
@triton.jit
def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
pass
x = torch.empty(1, device='cuda')
_kernel[(1,)](x, x.shape[0], 32)
try:
_kernel[(1,)](x.shape[0], x.shape[0], 32)
except AttributeError:
pass

View File

@@ -0,0 +1,22 @@
import torch
import triton
import triton.language as tl
def test_kwargs():
N = 1024
src = torch.empty(N, device='cuda')
dst = torch.empty(N, device='cuda')
configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
@triton.autotune(configs=configs, key=['N'])
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
_kernel[grid](dst, src, N)
_kernel[grid](dst=dst, src=src, N=N)