include dims in kernel/nv invalid err msg (#5487)

This commit is contained in:
chenyu
2024-07-14 22:51:30 -04:00
committed by GitHub
parent c80801c266
commit eef43c9f49
2 changed files with 3 additions and 3 deletions

View File

@@ -674,8 +674,8 @@ class Kernel:
def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
wd = self.global_dims
tcd = self.shape_len-self.upcasted
assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, "warp dims wrong"
assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, "tcd dims wrong"
assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
permaxis = list(range(wd))
for x,y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd))

View File

@@ -322,7 +322,7 @@ class NVProgram(HCQCompatProgram):
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
raise RuntimeError("Invalid global/local dims")
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
if self.device.kernargs_ptr >= (self.device.kernargs_page.va_addr + self.device.kernargs_page.size - self.kernargs_alloc_size):
self.device.kernargs_ptr = self.device.kernargs_page.va_addr