Make kernel_static_print test work when called twice. (#2518)

<git-pr-chain>

#### Commits in this PR
1. Make kernel_static_print test work when called twice.
    
This test checks that a message is printed when the kernel is compiled.
But the test had nothing to force the kernel to be compiled every time
you ran it. So after you ran it once, the test would fail on every
subsequent run until you cleared the cache.

#### [PR chain](https://github.com/jlebar/git-pr-chain)
1. 👉 #2518 👈 **YOU ARE HERE**
1. #2520


</git-pr-chain>
This commit is contained in:
Justin Lebar
2023-10-19 13:17:38 -07:00
committed by GitHub
parent 0d57820be9
commit bdf464e4a8

View File

@@ -1,4 +1,5 @@
import sys
import uuid
import torch
from torch.testing import assert_close
@@ -21,10 +22,12 @@ def kernel_print(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
# Take an extra value as a tl.constexpr so this kernel is not cached. This way
# the static print is run every time.
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.static_print(x)
tl.static_print("", x)
tl.store(Y + tl.arange(0, BLOCK), x)
@@ -43,7 +46,7 @@ def test_print(func: str, data_type: str):
elif func == "print":
kernel_print[(1,)](x, y, BLOCK=shape[0])
elif func == "static_print":
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
elif func == "no_arg_print":
kernel_no_arg_print[(1,)](num_warps=4)