mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
include dims in kernel/nv invalid err msg (#5487)
This commit is contained in:
@@ -674,8 +674,8 @@ class Kernel:
|
||||
def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
|
||||
wd = self.global_dims
|
||||
tcd = self.shape_len-self.upcasted
|
||||
assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, "warp dims wrong"
|
||||
assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, "tcd dims wrong"
|
||||
assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
|
||||
assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
|
||||
new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
|
||||
permaxis = list(range(wd))
|
||||
for x,y in pattern_1: permaxis.append(y + (wd if x == 0 else tcd))
|
||||
|
||||
@@ -322,7 +322,7 @@ class NVProgram(HCQCompatProgram):
|
||||
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
|
||||
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
|
||||
raise RuntimeError("Invalid global/local dims")
|
||||
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
|
||||
|
||||
if self.device.kernargs_ptr >= (self.device.kernargs_page.va_addr + self.device.kernargs_page.size - self.kernargs_alloc_size):
|
||||
self.device.kernargs_ptr = self.device.kernargs_page.va_addr
|
||||
|
||||
Reference in New Issue
Block a user