From 1b18ebb133f9cb4c2bb33a800529f9ca070cef63 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 1 Jun 2024 20:11:43 +0300 Subject: [PATCH] minor cleanups (#4802) --- tinygrad/runtime/ops_amd.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 536a2f87f4..720c175380 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -105,7 +105,7 @@ class HWPM4Queue: self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 6), gfxreg(amd_gpu.regCOMPUTE_PGM_LO), (prg.prog_addr>>8) & 0xFFFFFFFF, prg.prog_addr >> 40, 0, 0, (prg.device.scratch.va_addr>>8) & 0xFFFFFFFF, prg.device.scratch.va_addr >> 40] self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 2), gfxreg(amd_gpu.regCOMPUTE_PGM_RSRC1), prg.rsrc1, prg.rsrc2] - self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), gfxreg(amd_gpu.regCOMPUTE_TMPRING_SIZE), 0x00200200] # (waveSize << 12) | (numWaves) + self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), gfxreg(amd_gpu.regCOMPUTE_TMPRING_SIZE), prg.device.tmpring_size] self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 4), gfxreg(amd_gpu.regCOMPUTE_RESTART_X), 0, 0, 0, 0] self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 2), gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE0)] + [0xFFFFFFFF] * 2 self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 2), gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE2)] + [0xFFFFFFFF] * 2 @@ -293,12 +293,13 @@ class AMDProgram: def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): if self.device.kernargs_ptr + self.kernargs_segment_size > (self.device.kernargs.va_addr + self.device.kernargs.size): self.device.kernargs_ptr = self.device.kernargs.va_addr - assert self.device.kernargs_ptr + self.kernargs_segment_size <= (self.device.kernargs.va_addr + self.device.kernargs.size), "kernargs overrun" + if not hasattr(self, "args_struct_t"): self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] + [(f'v{i}', ctypes.c_int) for i in range(len(vals))])) if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size: - raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}") + raise RuntimeError(f"AMDProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}") + args_st = self.args_struct_t.from_address(self.device.kernargs_ptr) for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i].va_addr) for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i]) @@ -486,6 +487,8 @@ class AMDDevice(Compiled): wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256 self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) + engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine'] + self.tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines)) # SDMA Queue self.sdma_gart = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)