mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] Implement tl.device_assert and rename tl.printf to tl.device_print (#1143)
Note that `tl.device_print` and the device-side `print` accept different arguments than
Python's built-in `print`: the first argument must be a string, followed by
the variables to print.
Device side:
- `tl.device_print`
- `tl.device_assert`
- `print`
- `assert`
Compilation time:
- `tl.static_assert`
- `tl.static_print`
Usage example:
1.
```Python
tl.device_assert(x == 0, "x != 0")
```
Output:
```Python
...
python/test/unit/language/assert_helper.py:18: kernel: block: [0,0,0], thread: [33,0,0] Assertion `x != 0` failed.
...
```
2.
```Python
tl.device_print("hello ", x)
```
Output:
```Python
...
hello 1
...
```
The environment variable `TRITON_DEBUG` sets the default debugging flag; unless it is true, `tl.device_assert` and `assert` are skipped.
This commit is contained in:
45
python/test/unit/language/assert_helper.py
Normal file
45
python/test/unit/language/assert_helper.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
    # Load BLOCK elements from X, check on the device that each one is zero
    # via tl.device_assert, then copy the loaded values to Y.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    tl.device_assert(x == 0, "x != 0")
    tl.store(Y + offsets, x)
|
||||
|
||||
|
||||
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
    # Same flow as the device_assert kernel, but uses a plain `assert`
    # statement inside the jitted function.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    assert x == 0, "x != 0"
    tl.store(Y + offsets, x)
|
||||
|
||||
|
||||
@triton.jit
def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
    # Copy BLOCK elements from X to Y; the static assertion checks the
    # constexpr BLOCK rather than the loaded data.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    tl.static_assert(BLOCK == 128, "BLOCK != 128")
    tl.store(Y + offsets, x)
|
||||
|
||||
|
||||
def test_assert(func: str):
    """Launch the assert kernel named by ``func`` and compare output to input.

    The input is a 128-element int32 CUDA tensor holding 0..127, so only
    element 0 satisfies the kernels' ``x == 0`` condition.

    Args:
        func: One of ``"device_assert"``, ``"assert"`` or ``"static_assert"``.
            Any other value launches no kernel and goes straight to the final
            comparison.
    """
    shape = (128, )
    src = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
    dst = torch.zeros(shape, dtype=src.dtype, device="cuda")

    # Map the command-line name to the kernel it selects.
    kernels = {
        "device_assert": kernel_device_assert,
        "assert": kernel_assert,
        "static_assert": kernel_static_assert,
    }
    selected = kernels.get(func)
    if selected is not None:
        selected[(1,)](src, dst, BLOCK=shape[0])

    assert_close(dst, src)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The first command-line argument names the assert flavour to exercise.
    requested_func = sys.argv[1]
    test_assert(requested_func)
|
||||
46
python/test/unit/language/print_helper.py
Normal file
46
python/test/unit/language/print_helper.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
    # Load BLOCK elements from X, print them with tl.device_print (empty
    # prefix string), then copy them to Y.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    tl.device_print("", x)
    tl.store(Y + offsets, x)
|
||||
|
||||
|
||||
@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
    # Same flow as the device_print kernel, but calls plain `print` inside
    # the jitted function.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    print("", x)
    tl.store(Y + offsets, x)
|
||||
|
||||
|
||||
@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
    # Copy BLOCK elements from X to Y, passing the loaded value to
    # tl.static_print along the way.
    offsets = tl.arange(0, BLOCK)
    x = tl.load(X + offsets)
    tl.static_print(x)
    tl.store(Y + offsets, x)
|
||||
|
||||
|
||||
def test_print(func: str, data_type: str):
    """Launch the print kernel named by ``func`` and compare output to input.

    Args:
        func: One of ``"device_print"``, ``"print"`` or ``"static_print"``.
            Any other value launches no kernel and goes straight to the final
            comparison.
        data_type: Name of a ``torch`` dtype attribute (looked up with
            ``getattr(torch, data_type)``) to which the input is converted.
    """
    shape = (128, )
    # Build 0..127 as int32 first, then convert to the requested dtype.
    target_dtype = getattr(torch, data_type)
    src = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(target_dtype)
    dst = torch.zeros(shape, dtype=src.dtype, device="cuda")

    # Map the command-line name to the kernel it selects.
    kernels = {
        "device_print": kernel_device_print,
        "print": kernel_print,
        "static_print": kernel_static_print,
    }
    selected = kernels.get(func)
    if selected is not None:
        selected[(1,)](src, dst, BLOCK=shape[0])

    assert_close(dst, src)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage: print_helper.py <func> <torch dtype name>
    requested_func, dtype_name = sys.argv[1], sys.argv[2]
    test_print(requested_func, dtype_name)
|
||||
@@ -1,56 +0,0 @@
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# Maps a dtype name given as a string to the corresponding torch dtype.
torch_type = dict(
    [
        ("bool", torch.bool),
        ("int8", torch.int8),
        ("uint8", torch.uint8),
        ("int16", torch.int16),
        ("int32", torch.int32),
        ("int64", torch.long),
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
        ("float64", torch.float64),
    ]
)
|
||||
|
||||
|
||||
def get_tensor(shape, data_type, b_positive=False):
    """Return a CUDA tensor holding ``0 .. shape[0] - 1`` in the given dtype.

    The original body branched on whether ``data_type`` starts with ``"int"``
    but both branches built exactly the same tensor, so the branch (and the
    dead ``x = None`` initialiser) has been removed.

    Args:
        shape: Sequence whose first entry is the number of elements.
        data_type: Key into ``torch_type`` naming the desired dtype.
        b_positive: Unused; kept so existing callers keep working.

    Returns:
        The result of ``torch.arange(0, shape[0], ...)`` on the ``'cuda'``
        device.

    Raises:
        KeyError: If ``data_type`` is not a key of ``torch_type``.
    """
    return torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
|
||||
|
||||
# @pytest.mark.parametrize('data_type',
|
||||
# [("int8"),
|
||||
# ('int16'),
|
||||
# ('int32'),
|
||||
# ("int64"),
|
||||
# ('float16'),
|
||||
# ("float32"),
|
||||
# ("float64")])
|
||||
|
||||
|
||||
def printf(data_type):
    """Run a kernel that prints its input via ``tl.printf`` and copies it out.

    Args:
        data_type: dtype name forwarded to ``get_tensor``.
    """
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        offsets = tl.arange(0, BLOCK)
        x = tl.load(X + offsets)
        tl.printf("", x)
        tl.store(Y + offsets, x)

    shape = (128, )
    src = get_tensor(shape, data_type)
    dst = torch.zeros(shape, dtype=src.dtype, device="cuda")
    kernel[(1,)](src, dst, BLOCK=shape[0])
    # The kernel is a pass-through, so the output must equal the input.
    assert_close(dst, src)
|
||||
|
||||
|
||||
# Run the print kernel once per dtype, in this order.
for _dtype_name in ("float16", "int8"):
    printf(_dtype_name)
|
||||
@@ -1,22 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
printf_path = os.path.join(dir_path, "printf_helper.py")
|
||||
|
||||
|
||||
def test_printf():
    """Run the printf helper script and check that 0..127 were all printed."""
    completed = subprocess.run([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)

    # Collect every whitespace-separated stdout token that parses as a
    # number; anything else is reported and ignored.
    seen = set()
    for token in completed.stdout.split():
        try:
            seen.add(int(float(token)))
        except Exception as e:
            print(e)

    for expected in range(128):
        assert expected in seen
    assert len(seen) == 128
|
||||
53
python/test/unit/language/test_subprocess.py
Normal file
53
python/test/unit/language/test_subprocess.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Helper scripts live next to this test file and are run as subprocesses.
dir_path = os.path.dirname(os.path.realpath(__file__))
print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")

func_types = [
    "device_assert",
    "assert",
    "static_assert",
]
# TODO: bfloat16 after LLVM-15
torch_types = [
    "int8",
    "uint8",
    "int16",
    "int32",
    "long",
    "float16",
    "float32",
    "float64",
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "func_type, data_type",
    [("device_print", data_type) for data_type in torch_types]
    + [("print", "int32"), ("static_print", "int32")],
)
def test_print(func_type: str, data_type: str):
    """Run the print helper in a subprocess and validate what it wrote.

    For ``static_print`` the captured stdout must contain exactly one distinct
    whitespace-separated token. For the other variants, the numeric tokens
    must be exactly the integers 0..127.
    """
    command = [sys.executable, print_path, func_type, data_type]
    completed = subprocess.run(command, stdout=subprocess.PIPE, shell=False)
    tokens = completed.stdout.split()

    is_static = func_type == "static_print"
    seen = set()
    for token in tokens:
        if is_static:
            # Static output is kept verbatim; only its uniqueness is checked.
            seen.add(token)
            continue
        try:
            seen.add(int(float(token)))
        except Exception as e:
            # Non-numeric tokens are reported and ignored.
            print(e)

    if is_static:
        assert len(seen) == 1
    else:
        for expected in range(128):
            assert expected in seen
        assert len(seen) == 128
|
||||
|
||||
|
||||
@pytest.mark.parametrize("func_type", func_types)
def test_assert(func_type: str):
    """Run the assert helper with ``TRITON_DEBUG=1`` and count failure lines.

    ``TRITON_DEBUG`` is passed through the child's environment only. The
    previous version wrote it into this process's ``os.environ`` and then
    forced it to ``"0"`` afterwards, which discarded any pre-existing value
    and left ``"1"`` behind if the subprocess call raised.

    The helper feeds the values 0..127 to kernels asserting ``x == 0``, so the
    runtime variants are expected to report 127 failures. ``static_assert``
    checks a different condition with a different message, so no ``x != 0``
    line is expected for it.
    """
    child_env = dict(os.environ, TRITON_DEBUG="1")
    completed = subprocess.run(
        [sys.executable, assert_path, func_type],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=False,
        env=child_env,
    )

    # Match on raw bytes so undecodable stderr output cannot break the count.
    num_errs = sum(1 for line in completed.stderr.splitlines() if b"x != 0" in line)

    if func_type != "static_assert":
        assert num_errs == 127
    else:
        assert num_errs == 0
|
||||
@@ -148,6 +148,26 @@ def test_jit_warmup_cache() -> None:
|
||||
assert len(kernel_add.cache) == 1
|
||||
|
||||
|
||||
def test_jit_debug() -> None:
    """Check how changing ``debug`` affects the kernel's per-device cache.

    The cache holds one entry after the first warmup, still one after warming
    up with ``debug = False``, and two after warming up with ``debug = True``.
    """
    @triton.jit
    def kernel_add(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.device_assert(idx < 32, "idx < 32")
        total = tl.load(a + idx) + tl.load(b + idx)
        tl.store(o + idx, total)

    def warm_up():
        kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))

    device = torch.cuda.current_device()
    assert len(kernel_add.cache[device]) == 0

    warm_up()
    assert len(kernel_add.cache[device]) == 1

    kernel_add.debug = False
    warm_up()
    assert len(kernel_add.cache[device]) == 1

    kernel_add.debug = True
    warm_up()
    assert len(kernel_add.cache[device]) == 2
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
@triton.jit
|
||||
def kernel_sub(a, b, o, N: tl.constexpr):
|
||||
|
||||
Reference in New Issue
Block a user