[FRONTEND] Fix specialization on triton integer types (#2236)

https://github.com/openai/triton/issues/2231
This commit is contained in:
Keren Zhou
2023-09-04 02:57:08 -04:00
committed by GitHub
parent a539836876
commit 9e9fbe01f0
2 changed files with 21 additions and 5 deletions

View File

@@ -105,6 +105,21 @@ def test_specialize(mode):
assert counter == target
def test_annotation():
    """Check that an explicitly annotated integer argument still specializes.

    Launching with 1, 8, 16, and 17 should yield four distinct cache
    entries for this device (presumably one per specialization bucket:
    ==1, divisible-by-16, divisible-by-8, and unspecialized — the assert
    only pins the count).
    """
    @triton.jit
    def kernel(X, i: tl.int32):
        tl.store(X, i)

    device = torch.cuda.current_device()
    out = torch.empty(1, dtype=torch.int32, device='cuda')
    for value in (1, 8, 16, 17):
        kernel[(1,)](out, value)
    assert len(kernel.cache[device]) == 4
def test_constexpr_not_callable() -> None:
@triton.jit
def kernel(X, c: tl.constexpr):
@@ -138,13 +153,14 @@ def test_jit_warmup_cache() -> None:
torch.randn(32, dtype=torch.float32, device="cuda"),
32,
]
assert len(kernel_add.cache) == 0
device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache) == 1
assert len(kernel_add.cache[device]) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
assert len(kernel_add.cache[device]) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
assert len(kernel_add.cache[device]) == 1
def test_jit_debug() -> None:

View File

@@ -306,7 +306,7 @@ class JITFunction(KernelInterface[T]):
else (False,)'
elif 'Tensor' in arg_annotation:
return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)'
elif arg_annotation == 'int':
elif 'int' in arg_annotation or 'bool' in arg_annotation:
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)'
else:
return '(False,)'