mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix kernargs check in kfd (#4194)
This commit is contained in:
@@ -225,9 +225,9 @@ class KFDProgram:
|
||||
if hasattr(self, 'lib_gpu'): self.device._gpu_free(self.lib_gpu)
|
||||
|
||||
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
if self.device.kernargs_ptr >= (self.device.kernargs.va_addr + self.device.kernargs.size + self.kernargs_segment_size):
|
||||
if self.device.kernargs_ptr + self.kernargs_segment_size > (self.device.kernargs.va_addr + self.device.kernargs.size):
|
||||
self.device.kernargs_ptr = self.device.kernargs.va_addr
|
||||
assert self.device.kernargs_ptr < (self.device.kernargs.va_addr + self.device.kernargs.size + self.kernargs_segment_size), "kernargs overrun"
|
||||
assert self.device.kernargs_ptr + self.kernargs_segment_size <= (self.device.kernargs.va_addr + self.device.kernargs.size), "kernargs overrun"
|
||||
if not hasattr(self, "args_struct_t"):
|
||||
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
|
||||
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
|
||||
|
||||
Reference in New Issue
Block a user