mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix device_print without arguments (#2566)
This commit is contained in:
@@ -62,6 +62,11 @@ def kernel_no_arg_print():
|
||||
print("", tl.program_id(0))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_print_no_arg():
|
||||
print("no arg")
|
||||
|
||||
|
||||
def test_print(func: str, data_type: str):
|
||||
shape = (128, )
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
|
||||
@@ -80,10 +85,12 @@ def test_print(func: str, data_type: str):
|
||||
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
|
||||
elif func == "no_arg_print":
|
||||
kernel_no_arg_print[(1,)](num_warps=4)
|
||||
elif func == "print_no_arg":
|
||||
kernel_print_no_arg[(1,)](num_warps=4)
|
||||
else:
|
||||
assert f"Unknown kernel: {func}"
|
||||
|
||||
if func != "no_arg_print" and func != "device_print_large" and \
|
||||
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
|
||||
func != "print_multiple_args" and func != "device_print_multiple_args":
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32",
|
||||
("print", "int32"),
|
||||
("static_print", "int32"),
|
||||
("no_arg_print", "int32"),
|
||||
("print_no_arg", "int32"),
|
||||
("device_print_large", "int32"),
|
||||
("print_multiple_args", "int32"),
|
||||
("device_print_multiple_args", "int32"),
|
||||
@@ -44,6 +45,8 @@ def test_print(func_type: str, data_type: str):
|
||||
expected_lines[" int32[constexpr[128]]"] = 1
|
||||
elif func_type == "no_arg_print":
|
||||
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
|
||||
elif func_type == "print_no_arg":
|
||||
expected_lines["pid (0, 0, 0) no arg"] = 128
|
||||
elif func_type == "device_print_large":
|
||||
for i, j, k in itertools.product(range(2), range(64), range(128)):
|
||||
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
|
||||
|
||||
@@ -1579,9 +1579,9 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
|
||||
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
|
||||
# It makes sense visually for prefix to end in ": "; make it so. Also,
|
||||
# non-empty prefixes should start with " ".
|
||||
if not prefix.endswith(" "):
|
||||
if not prefix.endswith(" ") and args:
|
||||
prefix += " "
|
||||
if not prefix.endswith(": "):
|
||||
if not prefix.endswith(": ") and args:
|
||||
prefix = prefix[:-1] + ": "
|
||||
if len(prefix) > 2 and not prefix.startswith(" "):
|
||||
prefix = " " + prefix
|
||||
|
||||
Reference in New Issue
Block a user