ROCM IFU: Fix print and assert

This commit is contained in:
Jason Furmanek
2023-12-12 19:30:01 +00:00
parent 50a6db3afd
commit 160dfe838e
4 changed files with 37 additions and 30 deletions

View File

@@ -68,7 +68,8 @@ def kernel_print_no_arg():
def test_print(func: str, data_type: str):
shape = (128, )
#shape = (128, )
shape = (256, )
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_print":

View File

@@ -35,22 +35,22 @@ def test_print(func_type: str, data_type: str):
# pid (<x>, <y>, <z>) idx (<i1>, <i2>, ...) <prefix> (operand <n>) <elem>
expected_lines = Counter()
if func_type == "print" or func_type == "device_print":
for i in range(128):
for i in range(256):
line = f"pid (0, 0, 0) idx ({i:3}) x: {i}"
if data_type.startswith("float"):
line += ".000000"
expected_lines[line] = 1
elif func_type == "static_print":
expected_lines[" int32[constexpr[128]]"] = 1
expected_lines[" int32[constexpr[256]]"] = 1
elif func_type == "no_arg_print":
expected_lines["pid (0, 0, 0) idx (): 0"] = 128
expected_lines["pid (0, 0, 0) idx (): 0"] = 256
elif func_type == "print_no_arg":
expected_lines["pid (0, 0, 0) no arg"] = 128
expected_lines["pid (0, 0, 0) no arg"] = 256
elif func_type == "device_print_large":
for i, j, k in itertools.product(range(2), range(64), range(128)):
expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1
elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args":
for i in range(128):
for i in range(256):
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1
expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1