mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Revert "amd: use correct structs (#13583)"
This reverts commit d8b09eda57.
This commit is contained in:
@@ -998,23 +998,19 @@ class AMDDevice(HCQCompiled):
|
||||
waves = wave_scratch_len // (256 if self.target >= (11,0,0) else 1024)
|
||||
# >=gfx11 wavesize is per SE
|
||||
wavesize = scratch_size // ((wave_scratch_len * self.se_cnt) if self.target >= (11,0,0) else wave_scratch_len)
|
||||
|
||||
tmpring_t = getattr(hsa, f'union_COMPUTE_TMPRING_SIZE{"_GFX"+str(self.target[0]) if self.target[0] >= 11 else ""}_bitfields')
|
||||
self.tmpring_size = int.from_bytes(tmpring_t(waves=waves, wavesize=wavesize), 'little')
|
||||
self.tmpring_size = waves << 12 | wavesize
|
||||
self.max_private_segment_size = required
|
||||
|
||||
if hasattr(self, 'aql_desc'):
|
||||
gfx9_rsrc = {'NUM_FORMAT':hsa.BUF_NUM_FORMAT_UINT, 'DATA_FORMAT':hsa.BUF_DATA_FORMAT_32, 'ELEMENT_SIZE':1, 'INDEX_STRIDE':3}
|
||||
rsrc = {'DST_SEL_X':hsa.SQ_SEL_X, 'DST_SEL_Y':hsa.SQ_SEL_Y, 'DST_SEL_Z':hsa.SQ_SEL_Z, 'DST_SEL_W':hsa.SQ_SEL_W, 'ADD_TID_ENABLE':1,
|
||||
'TYPE':hsa.SQ_RSRC_BUF, **(gfx9_rsrc if self.target[0] < 10 else {'FORMAT':hsa.BUF_FORMAT_32_UINT, 'OOB_SELECT':2})}
|
||||
rsrc1_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD1{"_GFX11" if self.target[0] >= 11 else ""}_bitfields')
|
||||
rsrc3_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] >= 10 else ""}_bitfields')
|
||||
rsrc_t = getattr(hsa, f'union_SQ_BUF_RSRC_WORD3{"_GFX"+str(self.target[0]) if self.target[0] >= 10 else ""}_bitfields')
|
||||
|
||||
self.aql_desc.scratch_backing_memory_location = self.scratch.va_addr
|
||||
self.aql_desc.scratch_wave64_lane_byte_size = self.max_private_segment_size * (self.aql_desc.max_wave_id + 1) // 64
|
||||
self.aql_desc.scratch_resource_descriptor[:] = [lo32(self.scratch.va_addr),
|
||||
int.from_bytes(rsrc1_t(base_address_hi=hi32(self.scratch.va_addr), swizzle_enable=1), 'little'),
|
||||
lo32(scratch_size), int.from_bytes(bytes(rsrc3_t(**rsrc)), 'little')]
|
||||
self.aql_desc.scratch_resource_descriptor[:] = [lo32(self.scratch.va_addr), hi32(self.scratch.va_addr) | (1 << 30), lo32(scratch_size),
|
||||
int.from_bytes(bytes(rsrc_t(**rsrc)), 'little')]
|
||||
self.aql_desc.compute_tmpring_size = self.tmpring_size
|
||||
|
||||
def invalidate_caches(self):
|
||||
|
||||
Reference in New Issue
Block a user