mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-30 09:18:07 -05:00
minor cleanups (#4802)
This commit is contained in:
@@ -105,7 +105,7 @@ class HWPM4Queue:
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 6), gfxreg(amd_gpu.regCOMPUTE_PGM_LO), (prg.prog_addr>>8) & 0xFFFFFFFF,
|
||||
prg.prog_addr >> 40, 0, 0, (prg.device.scratch.va_addr>>8) & 0xFFFFFFFF, prg.device.scratch.va_addr >> 40]
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 2), gfxreg(amd_gpu.regCOMPUTE_PGM_RSRC1), prg.rsrc1, prg.rsrc2]
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), gfxreg(amd_gpu.regCOMPUTE_TMPRING_SIZE), 0x00200200] # (waveSize << 12) | (numWaves)
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), gfxreg(amd_gpu.regCOMPUTE_TMPRING_SIZE), prg.device.tmpring_size]
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 4), gfxreg(amd_gpu.regCOMPUTE_RESTART_X), 0, 0, 0, 0]
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 2), gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE0)] + [0xFFFFFFFF] * 2
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 2), gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE2)] + [0xFFFFFFFF] * 2
|
||||
@@ -293,12 +293,13 @@ class AMDProgram:
|
||||
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
if self.device.kernargs_ptr + self.kernargs_segment_size > (self.device.kernargs.va_addr + self.device.kernargs.size):
|
||||
self.device.kernargs_ptr = self.device.kernargs.va_addr
|
||||
assert self.device.kernargs_ptr + self.kernargs_segment_size <= (self.device.kernargs.va_addr + self.device.kernargs.size), "kernargs overrun"
|
||||
|
||||
if not hasattr(self, "args_struct_t"):
|
||||
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
|
||||
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
|
||||
if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
|
||||
raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")
|
||||
raise RuntimeError(f"AMDProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")
|
||||
|
||||
args_st = self.args_struct_t.from_address(self.device.kernargs_ptr)
|
||||
for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i].va_addr)
|
||||
for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
|
||||
@@ -486,6 +487,8 @@ class AMDDevice(Compiled):
|
||||
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256
|
||||
self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len
|
||||
self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
|
||||
engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine']
|
||||
self.tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines))
|
||||
|
||||
# SDMA Queue
|
||||
self.sdma_gart = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
|
||||
|
||||
Reference in New Issue
Block a user