Revert "amd: use correct structs (#13583)"

This reverts commit d8b09eda57.
This commit is contained in:
chenyu
2025-12-05 09:24:02 -05:00
committed by GitHub
parent 8c332219f9
commit dec2ea8a28

View File

@@ -998,23 +998,19 @@ class AMDDevice(HCQCompiled):
waves = wave_scratch_len // (256 if self.target >= (11,0,0) else 1024)
# >=gfx11 wavesize is per SE
wavesize = scratch_size // ((wave_scratch_len * self.se_cnt) if self.target >= (11,0,0) else wave_scratch_len)
tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] >= 11 else ""}_bitfields')
self.tmpring_size = int.from_bytes(tmpring_t(waves=waves, wavesize=wavesize), 'little')
self.tmpring_size = waves << 12 | wavesize
self.max_private_segment_size = required
if hasattr(self, 'aql_desc'):
gfx9_rsrc = {'NUM_FORMAT':hsa.BUF_NUM_FORMAT_UINT, 'DATA_FORMAT':hsa.BUF_DATA_FORMAT_32, 'ELEMENT_SIZE':1, 'INDEX_STRIDE':3}
rsrc = {'DST_SEL_X':hsa.SQ_SEL_X, 'DST_SEL_Y':hsa.SQ_SEL_Y, 'DST_SEL_Z':hsa.SQ_SEL_Z, 'DST_SEL_W':hsa.SQ_SEL_W, 'ADD_TID_ENABLE':1,
'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] < 10 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})}
rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] >= 11 else ""}_bitfields')
rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] >= 10 else ""}_bitfields')
rsrc_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] >= 10 else ""}_bitfields')
self.aql_desc.scratch_backing_memory_location = self.scratch.va_addr
self.aql_desc.scratch_wave64_lane_byte_size = self.max_private_segment_size * (self.aql_desc.max_wave_id + 1) // 64
self.aql_desc.scratch_resource_descriptor[:] = [lo32(self.scratch.va_addr),
int.from_bytes(rsrc1_t(base_address_hi=hi32(self.scratch.va_addr), swizzle_enable=1), 'little'),
lo32(scratch_size), int.from_bytes(bytes(rsrc3_t(**rsrc)), 'little')]
self.aql_desc.scratch_resource_descriptor[:] = [lo32(self.scratch.va_addr), hi32(self.scratch.va_addr) | (1 << 30), lo32(scratch_size),
int.from_bytes(bytes(rsrc_t(**rsrc)), 'little')]
self.aql_desc.compute_tmpring_size = self.tmpring_size
def invalidate_caches(self):