diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 043818dd6f..62c53c506b 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -82,7 +82,12 @@ class AMDComputeQueue(HWQueue): def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]): self.acquire_mem(gli=0, gl2=0) - user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else [] + if prg.enable_private_segment_sgpr: + scratch_hilo = data64_le(prg.dev.scratch.va_addr) + # sgpr word1 bit31 enables swizzle + # sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23 + user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000] if prg.enable_private_segment_sgpr else [] + else: user_regs = [] if prg.enable_dispatch_ptr: dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size) @@ -370,12 +375,16 @@ class AMDDevice(HCQCompiled): max_cu_id = self.properties['simd_count'] // self.properties['simd_per_cu'] - 1 max_wave_id = self.properties['max_waves_per_simd'] * self.properties['simd_per_cu'] - 1 self.max_private_segment_size = 4096 - wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256 + # =gfx11 requires 256 + wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256 if self.target >= 110000 else 1024) self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) self.has_scratch_base_registers = self.target >= 110000 engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine'] - self.tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines)) + waves = wave_scratch_len // (256 if self.target >= 110000 else 1024) + # >=gfx11 wavesize is per SE + wavesize = self.scratch_len // ((wave_scratch_len * engines) if self.target >= 110000 else wave_scratch_len) + self.tmpring_size = waves << 12 | wavesize # https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391 sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000