diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index 17dd681638..f5b7585589 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -230,7 +230,7 @@ class NVProgram(HCQProgram): # NOTE: Ensure at least 4KB of space after the program to mitigate prefetch memory faults. self.lib_gpu = self.device.allocator.alloc(round_up(image.nbytes, 0x1000) + 0x1000, BufferOptions(cpu_access=True)) - self.program_addr, self.program_sz, self.registers_usage, self.shmem_usage = self.lib_gpu.va_addr, image.nbytes, 0, 0 + self.program_addr, self.program_sz, self.registers_usage, self.shmem_usage, self.lcmem_usage = self.lib_gpu.va_addr, image.nbytes, 0, 0, 0 self.constbufs: Dict[int, Tuple[int, int]] = {0: (0, 0x160)} # Dict[constbuf index, Tuple[va_addr, size]] for sh in sections: if sh.name == f".nv.shared.{self.name}": self.shmem_usage = sh.header.sh_size @@ -240,7 +240,10 @@ class NVProgram(HCQProgram): elif sh.name == ".nv.info": for off in range(0, sh.header.sh_size, 12): typ, _, val = struct.unpack_from("III", sh.content, off) - if typ & 0xffff == 0x1204: self.device._ensure_has_local_memory(val + 0x240) + if typ & 0xffff == 0x1204: self.lcmem_usage = val + 0x240 + + # Ensure device has enough local memory to run the program + self.device._ensure_has_local_memory(self.lcmem_usage) # Apply relocs for apply_image_offset, rel_sym_offset, typ, _ in relocs: @@ -281,7 +284,8 @@ class NVProgram(HCQProgram): if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, BufferOptions(cpu_access=True)) def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): - if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch") + if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.device).slm_per_thread: + raise RuntimeError("Too many resources requested for launch") if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])): raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}") return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait) @@ -518,7 +522,7 @@ class NVDevice(HCQCompiled): def _setup_gpfifos(self): # Set windows addresses to not collide with other allocated buffers. - self.shared_mem_window, self.local_mem_window, self.slm_per_thread = 0xfe000000, 0xff000000, 0 + self.shared_mem_window, self.local_mem_window, self.slm_per_thread, self.shader_local_mem = 0xfe000000, 0xff000000, 0, None NVComputeQueue().setup(compute_class=self.compute_class, local_mem_window=self.local_mem_window, shared_mem_window=self.shared_mem_window) \ .signal(self.timeline_signal, self.timeline_value).submit(self) @@ -532,19 +536,18 @@ class NVDevice(HCQCompiled): def _ensure_has_local_memory(self, required): if self.slm_per_thread >= required: return - if hasattr(self, 'shader_local_mem'): - self.allocator.free(self.shader_local_mem, BufferOptions(nolru=True)) # type: ignore # pylint: disable=access-member-before-definition + if self.shader_local_mem is not None: self.allocator.free(self.shader_local_mem, self.shader_local_mem.size) self.slm_per_thread, old_slm_per_thread = round_up(required, 32), self.slm_per_thread bytes_per_warp = round_up(self.slm_per_thread * 32, 0x200) bytes_per_tpc = round_up(bytes_per_warp * 48 * 2, 0x8000) - try: self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * 64, 0x20000), BufferOptions(nolru=True)) + try: self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * 64, 0x20000)) except MemoryError: # If can't allocate a new size, reallocator the old buffer. self.slm_per_thread = old_slm_per_thread bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * 48 * 2, 0x8000) - self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * 64, 0x20000), BufferOptions(nolru=True)) + self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * 64, 0x20000)) NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1) \ .setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \