mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Make kernel_static_print test work when called twice. (#2518)
<git-pr-chain>
#### Commits in this PR
1. Make kernel_static_print test work when called twice.
This test is checking that a message is printed when the kernel is
compiled.
But the test had nothing to force the kernel to be compiled every time
you ran
the test. So after you ran it once, the test would fail every time until
you
cleared the cache.
#### [PR chain](https://github.com/jlebar/git-pr-chain)
1. 👉 #2518 👈 **YOU ARE HERE**
1. #2520
</git-pr-chain>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
@@ -21,10 +22,12 @@ def kernel_print(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
# Take an extra value as a tl.constexpr so this kernel is not cached. This way
|
||||
# the static print is run every time.
|
||||
@triton.jit
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
|
||||
def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.static_print(x)
|
||||
tl.static_print("", x)
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@@ -43,7 +46,7 @@ def test_print(func: str, data_type: str):
|
||||
elif func == "print":
|
||||
kernel_print[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_print":
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
|
||||
elif func == "no_arg_print":
|
||||
kernel_no_arg_print[(1,)](num_warps=4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user