Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
[FRONTEND] Fix specialization on triton integer types (#2236)
https://github.com/openai/triton/issues/2231
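For context: Triton's JIT compiles a separate binary for each distinct specialization of a kernel's scalar integer arguments and reuses it from a cache on later launches; this commit makes that work correctly when the argument is annotated with a Triton integer type such as tl.int32. The sketch below only illustrates the caching behaviour being tested and is not part of the commit; it assumes a CUDA device and a Triton build (like the one under test here) where a jitted function exposes a cache dict keyed by device index, and the kernel name store_int and the argument values are made up for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def store_int(X, i):  # `i` is un-annotated, so it is subject to value specialization
    tl.store(X, i)


x = torch.empty(1, dtype=torch.int32, device="cuda")
device = torch.cuda.current_device()

store_int[(1,)](x, 2)   # first launch compiles one specialization
store_int[(1,)](x, 3)   # 2 and 3 should share a specialization key, reusing the cached binary
store_int[(1,)](x, 1)   # 1 is typically specialized separately, so this may compile again
print(len(store_int.cache[device]))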
@@ -105,6 +105,21 @@ def test_specialize(mode):
     assert counter == target
 
 
+def test_annotation():
+    @triton.jit
+    def kernel(X, i: tl.int32):
+        tl.store(X, i)
+
+    x = torch.empty(1, dtype=torch.int32, device='cuda')
+
+    device = torch.cuda.current_device()
+    kernel[(1,)](x, 1)
+    kernel[(1,)](x, 8)
+    kernel[(1,)](x, 16)
+    kernel[(1,)](x, 17)
+    assert len(kernel.cache[device]) == 4
+
+
 def test_constexpr_not_callable() -> None:
     @triton.jit
     def kernel(X, c: tl.constexpr):
@@ -138,13 +153,14 @@ def test_jit_warmup_cache() -> None:
         torch.randn(32, dtype=torch.float32, device="cuda"),
         32,
     ]
-    assert len(kernel_add.cache) == 0
+    device = torch.cuda.current_device()
+    assert len(kernel_add.cache[device]) == 0
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
-    assert len(kernel_add.cache) == 1
+    assert len(kernel_add.cache[device]) == 1
     kernel_add.warmup(*args, grid=(1,))
-    assert len(kernel_add.cache) == 1
+    assert len(kernel_add.cache[device]) == 1
     kernel_add.warmup(*args, grid=(1,))
-    assert len(kernel_add.cache) == 1
+    assert len(kernel_add.cache[device]) == 1
 
 
 def test_jit_debug() -> None:
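The second hunk switches test_jit_warmup_cache over to the same per-device cache. As a usage note, warmup compiles a kernel ahead of time without launching it, with dtypes standing in for the tensor arguments, and repeated warmups with the same signature should not add cache entries. Below is a minimal sketch of that pattern, assuming a CUDA device and the per-device cache layout shown in the diff; the kernel name add_kernel and its body are made up for illustration and are not the kernel_add used by the test.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr):
    offs = tl.arange(0, n)
    tl.store(out_ptr + offs, tl.load(x_ptr + offs) + tl.load(y_ptr + offs))


device = torch.cuda.current_device()

# Compile ahead of time: dtypes stand in for the tensor arguments.
add_kernel.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(add_kernel.cache[device]) == 1

# Warming up the same signature again should not grow the cache.
add_kernel.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(add_kernel.cache[device]) == 1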