[BACKEND] Fix device_print without arguments (#2566)

This commit is contained in:
Keren Zhou
2023-10-30 20:04:44 -04:00
committed by GitHub
parent 492886fcde
commit 70fca00b67
4 changed files with 45 additions and 34 deletions

View File

@@ -62,6 +62,11 @@ def kernel_no_arg_print():
print("", tl.program_id(0))
@triton.jit
def kernel_print_no_arg():
print("no arg")
def test_print(func: str, data_type: str):
shape = (128, )
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
@@ -80,10 +85,12 @@ def test_print(func: str, data_type: str):
kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4())
elif func == "no_arg_print":
kernel_no_arg_print[(1,)](num_warps=4)
elif func == "print_no_arg":
kernel_print_no_arg[(1,)](num_warps=4)
else:
assert f"Unknown kernel: {func}"
if func != "no_arg_print" and func != "device_print_large" and \
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args":
assert_close(y, x)

View File

@@ -22,6 +22,7 @@ torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32",
("print", "int32"),
("static_print", "int32"),
("no_arg_print", "int32"),
("print_no_arg", "int32"),
("device_print_large", "int32"),
("print_multiple_args", "int32"),
("device_print_multiple_args", "int32"),
@@ -44,6 +45,8 @@ def test_print(func_type: str, data_type: str):
expected_lines[" int32[constexpr[128]]"] = 1
elif func_type == "no_arg_print":
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
elif func_type == "print_no_arg":
expected_lines["pid (0, 0, 0) no arg"] = 128
elif func_type == "device_print_large":
for i, j, k in itertools.product(range(2), range(64), range(128)):
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1

View File

@@ -1579,9 +1579,9 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
# It makes sense visually for prefix to end in ": "; make it so. Also,
# non-empty prefixes should start with " ".
if not prefix.endswith(" "):
if not prefix.endswith(" ") and args:
prefix += " "
if not prefix.endswith(": "):
if not prefix.endswith(": ") and args:
prefix = prefix[:-1] + ": "
if len(prefix) > 2 and not prefix.startswith(" "):
prefix = " " + prefix