diff --git a/tinygrad/runtime/ops_kfd.py b/tinygrad/runtime/ops_kfd.py index a3f4c53b32..8e8a12c187 100644 --- a/tinygrad/runtime/ops_kfd.py +++ b/tinygrad/runtime/ops_kfd.py @@ -234,8 +234,9 @@ class KFDDevice(Compiled): self.amd_aql_queue.max_wave_id = self.properties['max_waves_per_simd'] * self.properties['simd_per_cu'] - 1 # scratch setup - self.max_private_segment_size = 256 - self.scratch_len = self.max_private_segment_size * (self.amd_aql_queue.max_cu_id + 1) * (self.amd_aql_queue.max_wave_id + 1) + self.max_private_segment_size = 512 + wave_scratch_len = round_up(((self.amd_aql_queue.max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256 + self.scratch_len = (self.amd_aql_queue.max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) self.amd_aql_queue.scratch_backing_memory_location = self.scratch.va_addr self.amd_aql_queue.scratch_backing_memory_byte_size = self.scratch_len @@ -244,9 +245,8 @@ class KFDDevice(Compiled): self.amd_aql_queue.scratch_resource_descriptor[1] = ((self.scratch.va_addr >> 32) & 0xFFFF) | (1 << 30) # va_hi | SWIZZLE_ENABLE self.amd_aql_queue.scratch_resource_descriptor[2] = self.scratch_len & 0xFFFFFFFF self.amd_aql_queue.scratch_resource_descriptor[3] = 0x20814fac # FORMAT=BUF_FORMAT_32_UINT,OOB_SELECT=2,ADD_TID_ENABLE=1,TYPE=SQ_RSRC_BUF,SQ_SELs - - wave_scratch = (((self.amd_aql_queue.max_wave_id + 1) * self.max_private_segment_size + 255) // 256) - self.amd_aql_queue.compute_tmpring_size = wave_scratch << 12 | (self.amd_aql_queue.max_cu_id + 1) + engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine'] + self.amd_aql_queue.compute_tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines)) self.aql_queue = kio.create_queue(KFDDevice.kfd, ring_base_address=self.aql_ring.va_addr, ring_size=self.aql_ring.size, gpu_id=self.gpu_id, queue_type=kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL, queue_percentage=kfd.KFD_MAX_QUEUE_PERCENTAGE, queue_priority=kfd.KFD_MAX_QUEUE_PRIORITY, @@ -292,8 +292,11 @@ class KFDDevice(Compiled): def _submit_sdma(self, dest, src, copy_size, wait_signals=None, completion_signal=None): def blit_sdma_command(cmd): - ctypes.memmove(self.sdma_ring.va_addr + (self.sdma_doorbell_value % self.sdma_ring.size), ctypes.addressof(cmd), sz:=ctypes.sizeof(cmd)) - self.sdma_doorbell_value += sz + if (cmdsz:=ctypes.sizeof(cmd)) > (fill:=self.sdma_ring.size - self.sdma_doorbell_value % self.sdma_ring.size): + ctypes.memset(self.sdma_ring.va_addr + (self.sdma_doorbell_value % self.sdma_ring.size), 0, fill) + self.sdma_doorbell_value += fill + ctypes.memmove(self.sdma_ring.va_addr + (self.sdma_doorbell_value % self.sdma_ring.size), ctypes.addressof(cmd), cmdsz) + self.sdma_doorbell_value += cmdsz if wait_signals is not None: # NOTE: we check only low 32 bits to be zeroed, we don't use higher values for signals