mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTING] Added more tests for annotations and autotuner (#1533)
Essentially identical to #538, but it fails formatting tests and I don't want to ping the author on a weekend.
This commit is contained in:
21
python/test/unit/language/test_annotations.py
Normal file
21
python/test/unit/language/test_annotations.py
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def test_annotations():
|
||||
|
||||
@triton.jit
|
||||
def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
|
||||
pass
|
||||
|
||||
x = torch.empty(1, device='cuda')
|
||||
_kernel[(1,)](x, x.shape[0], 32)
|
||||
try:
|
||||
_kernel[(1,)](x.shape[0], x.shape[0], 32)
|
||||
except AttributeError:
|
||||
pass
|
||||
22
python/test/unit/runtime/test_autotuner.py
Normal file
22
python/test/unit/runtime/test_autotuner.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def test_kwargs():
|
||||
N = 1024
|
||||
src = torch.empty(N, device='cuda')
|
||||
dst = torch.empty(N, device='cuda')
|
||||
|
||||
configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
|
||||
|
||||
@triton.autotune(configs=configs, key=['N'])
|
||||
@triton.jit
|
||||
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
|
||||
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
x = tl.load(src + offsets, mask=offsets < N)
|
||||
tl.store(dst + offsets, x, mask=offsets < N)
|
||||
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
|
||||
_kernel[grid](dst, src, N)
|
||||
_kernel[grid](dst=dst, src=src, N=N)
|
||||
Reference in New Issue
Block a user