mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
nv correct local memory based on device (#7307)
* nv correct local memory based on device * linter * oops * oops2
This commit is contained in:
@@ -486,9 +486,9 @@ class NVDevice(HCQCompiled):
|
||||
self.cmdq: memoryview = to_mv(self.cmdq_page.va_addr, 0x200000).cast("I")
|
||||
self.cmdq_wptr: int = 0 # in bytes
|
||||
|
||||
sm_info = nv_gpu.NV2080_CTRL_GR_INFO(index=nv_gpu.NV2080_CTRL_GR_INFO_INDEX_SM_VERSION)
|
||||
rmctrl.gr_get_info(self.fd_ctl, self.root, self.subdevice, grInfoListSize=1, grInfoList=ctypes.addressof(sm_info))
|
||||
self.arch: str = f"sm_{(sm_info.data>>8)&0xff}{(val>>4) if (val:=sm_info.data&0xff) > 0xf else val}"
|
||||
self.num_gpcs, self.num_tpc_per_gpc, self.num_sm_per_tpc, self.max_warps_per_sm, self.sm_version = self._query_gpu_info('num_gpcs',
|
||||
'num_tpc_per_gpc', 'num_sm_per_tpc', 'max_warps_per_sm', 'sm_version')
|
||||
self.arch: str = f"sm_{(self.sm_version>>8)&0xff}{(val>>4) if (val:=self.sm_version&0xff) > 0xf else val}"
|
||||
|
||||
compiler_t = (PTXCompiler if PTX else CUDACompiler) if MOCKGPU else (NVPTXCompiler if PTX else NVCompiler)
|
||||
super().__init__(device, NVAllocator(self), PTXRenderer(self.arch, device="NV") if PTX else NVRenderer(self.arch), compiler_t(self.arch),
|
||||
@@ -520,6 +520,12 @@ class NVDevice(HCQCompiled):
|
||||
return GPFifo(ring=to_mv(gpfifo_area.va_addr + offset, entries * 8).cast("Q"), entries_count=entries, token=ws_token_params.workSubmitToken,
|
||||
controls=nv_gpu.AmpereAControlGPFifo.from_address(gpfifo_area.va_addr + offset + entries * 8))
|
||||
|
||||
def _query_gpu_info(self, *reqs):
  """Fetch GR info values for the requested names with a single `gr_get_info` call.

  Each name in `reqs` is upper-cased and looked up on `nv_gpu` as
  `NV2080_CTRL_GR_INFO_INDEX_<NAME>`; when that attribute is missing,
  `NV2080_CTRL_GR_INFO_INDEX_LITTER_<NAME>` is used instead (or None if
  neither exists).

  Returns a list holding the `data` field of every queried entry, in the
  same order as `reqs`.
  """
  indices = []
  for req in reqs:
    upper_name = req.upper()
    # The fallback is resolved first so it can serve as the default for the primary lookup.
    litter_index = getattr(nv_gpu, 'NV2080_CTRL_GR_INFO_INDEX_LITTER_' + upper_name, None)
    indices.append(getattr(nv_gpu, 'NV2080_CTRL_GR_INFO_INDEX_' + upper_name, litter_index))

  # One ctypes array with an entry per requested index; its address is passed to the control call.
  info_array_type = nv_gpu.NV2080_CTRL_GR_INFO * len(indices)
  infos = info_array_type(*[nv_gpu.NV2080_CTRL_GR_INFO(index=idx) for idx in indices])
  rmctrl.gr_get_info(self.fd_ctl, self.root, self.subdevice,
                     grInfoListSize=len(infos), grInfoList=ctypes.addressof(infos))

  return [entry.data for entry in infos]
|
||||
|
||||
def _setup_gpfifos(self):
|
||||
# Set windows addresses to not collide with other allocated buffers.
|
||||
self.shared_mem_window, self.local_mem_window, self.slm_per_thread, self.shader_local_mem = 0xfe000000, 0xff000000, 0, None
|
||||
@@ -539,15 +545,14 @@ class NVDevice(HCQCompiled):
|
||||
if self.shader_local_mem is not None: self.allocator.free(self.shader_local_mem, self.shader_local_mem.size)
|
||||
|
||||
self.slm_per_thread, old_slm_per_thread = round_up(required, 32), self.slm_per_thread
|
||||
bytes_per_warp = round_up(self.slm_per_thread * 32, 0x200)
|
||||
bytes_per_tpc = round_up(bytes_per_warp * 48 * 2, 0x8000)
|
||||
bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * self.max_warps_per_sm * self.num_sm_per_tpc, 0x8000)
|
||||
|
||||
try: self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * 64, 0x20000))
|
||||
try: self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * self.num_tpc_per_gpc * self.num_gpcs, 0x20000))
|
||||
except MemoryError:
|
||||
# If a buffer of the new size can't be allocated, reallocate the old buffer.
|
||||
self.slm_per_thread = old_slm_per_thread
|
||||
bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * 48 * 2, 0x8000)
|
||||
self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * 64, 0x20000))
|
||||
bytes_per_tpc = round_up(round_up(self.slm_per_thread * 32, 0x200) * self.max_warps_per_sm * self.num_sm_per_tpc, 0x8000)
|
||||
self.shader_local_mem = self.allocator.alloc(round_up(bytes_per_tpc * self.num_tpc_per_gpc * self.num_gpcs, 0x20000))
|
||||
|
||||
NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1) \
|
||||
.setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
|
||||
|
||||
Reference in New Issue
Block a user