mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
rdna4
This commit is contained in:
@@ -498,8 +498,8 @@ def _disasm_mimg(inst: MIMG) -> str:
|
||||
dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
|
||||
vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
|
||||
vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr)
|
||||
# modifiers
|
||||
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
|
||||
# modifiers - include any non-zero dmask (even 0xf) for image load/store/atomic (LLVM uses it for vdata size validation)
|
||||
mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else []
|
||||
mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
|
||||
for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
|
||||
(inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]:
|
||||
@@ -625,9 +625,21 @@ def _disasm_vflat(inst) -> str:
|
||||
else: op_enum, prefix = VFLATOp, 'flat'
|
||||
name = op_enum(inst.op).name.lower()
|
||||
|
||||
# global_wb, global_wbinv, global_inv are cache control instructions with no operands
|
||||
if name in ('global_wb', 'global_wbinv', 'global_inv'):
|
||||
mods = _rdna4_mem_mods(inst.th, inst.scope, False, False)
|
||||
return f"{name}" + (f" {mods}" if mods else "")
|
||||
|
||||
# addtid instructions use thread ID as address offset, no vaddr operand
|
||||
is_addtid = 'addtid' in name
|
||||
|
||||
# Data width based on instruction name suffix
|
||||
suffix = name.split('_')[-1]
|
||||
base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
|
||||
# block loads/stores use 32 VGPRs
|
||||
if 'block' in name:
|
||||
base_w = 32
|
||||
else:
|
||||
base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
|
||||
# For cmpswap: vsrc holds cmp+data pairs (2x base), vdst is base width
|
||||
vsrc_w = base_w * 2 if 'cmpswap' in name else base_w
|
||||
vdst_w = base_w
|
||||
@@ -640,10 +652,10 @@ def _disasm_vflat(inst) -> str:
|
||||
is_store, is_atomic = 'store' in name, 'atomic' in name
|
||||
mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic)
|
||||
|
||||
# saddr handling - VGLOBAL needs explicit "off" when saddr=124
|
||||
if inst.saddr == 124: saddr_s = ", off" if prefix == 'global' else ""
|
||||
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2)}"
|
||||
# saddr handling - VGLOBAL and VSCRATCH need explicit "off" when saddr=124
|
||||
if inst.saddr == 124: saddr_s = "off" if prefix in ('global', 'scratch') else ""
|
||||
elif inst.saddr in SPECIAL_PAIRS: saddr_s = SPECIAL_PAIRS[inst.saddr]
|
||||
else: saddr_s = _sreg(inst.saddr, 2) if prefix == 'global' else decode_src(inst.saddr)
|
||||
|
||||
# Address width: 1 for scratch, or for non-flat (global) when an saddr is present (saddr != 124); 2 otherwise
|
||||
addr_w = 1 if (prefix == 'scratch' or (inst.saddr != 124 and prefix != 'flat')) else 2
|
||||
@@ -651,6 +663,14 @@ def _disasm_vflat(inst) -> str:
|
||||
vsrc_s = _vreg(inst.vsrc, vsrc_w)
|
||||
vdst_s = _vreg(inst.vdst, vdst_w)
|
||||
|
||||
# addtid instructions don't have vaddr, just vdata and saddr
|
||||
if is_addtid:
|
||||
if is_store: return f"{name} {vsrc_s}, {saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
return f"{name} {vdst_s}, {saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
|
||||
# Regular instructions need comma before saddr
|
||||
saddr_s = f", {saddr_s}" if saddr_s else ""
|
||||
|
||||
if is_atomic:
|
||||
if inst.th == 1: # TH_ATOMIC_RETURN
|
||||
return f"{name} {vdst_s}, {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "")
|
||||
|
||||
@@ -13,7 +13,7 @@ RDNA3_TEST_FILES = {
|
||||
'vop3_from_vop1': 'gfx11_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx11_asm_vop3_from_vop2.s',
|
||||
'vop3_from_vopc': 'gfx11_asm_vop3_from_vopc.s', 'vop3_from_vopcx': 'gfx11_asm_vop3_from_vopcx.s',
|
||||
'ds': 'gfx11_asm_ds.s', 'smem': 'gfx11_asm_smem.s', 'flat': 'gfx11_asm_flat.s',
|
||||
'mubuf': 'gfx11_asm_mubuf.s', 'mtbuf': 'gfx11_asm_mtbuf.s', 'mimg': 'gfx11_asm_mimg.s', 'ldsdir': 'gfx11_asm_ldsdir.s',
|
||||
'mubuf': 'gfx11_asm_mubuf.s', 'mtbuf': 'gfx11_asm_mtbuf.s', 'mimg': 'gfx11_asm_mimg.s', 'mimg_features': 'gfx11_asm_mimg_features.s', 'ldsdir': 'gfx11_asm_ldsdir.s',
|
||||
'exp': 'gfx11_asm_exp.s', 'wmma': 'gfx11_asm_wmma.s',
|
||||
'vop3_features': 'gfx11_asm_vop3_features.s', 'vop3p_features': 'gfx11_asm_vop3p_features.s', 'vopd_features': 'gfx11_asm_vopd_features.s',
|
||||
'vop3_alias': 'gfx11_asm_vop3_alias.s', 'vop3p_alias': 'gfx11_asm_vop3p_alias.s', 'vopc_alias': 'gfx11_asm_vopc_alias.s',
|
||||
@@ -32,7 +32,8 @@ RDNA4_TEST_FILES = {
|
||||
'vop3_from_vop1': 'gfx12_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx12_asm_vop3_from_vop2.s',
|
||||
'ds': 'gfx12_asm_ds.s', 'ds_alias': 'gfx12_asm_ds_alias.s', 'smem': 'gfx12_asm_smem.s',
|
||||
'vflat': 'gfx12_asm_vflat.s', 'vflat_alias': 'gfx12_asm_vflat_alias.s',
|
||||
'vscratch': 'gfx12_asm_vflat.s', 'vscratch_alias': 'gfx12_asm_vflat_alias.s', # scratch instructions in vflat files
|
||||
'vglobal': 'gfx12_asm_vflat.s', 'vglobal_alias': 'gfx12_asm_vflat_alias.s', # global instructions in vflat files
|
||||
'vscratch': 'gfx12_asm_vflat.s', # scratch instructions in vflat file
|
||||
'vbuffer_mubuf': 'gfx12_asm_vbuffer_mubuf.s', 'vbuffer_mubuf_alias': 'gfx12_asm_vbuffer_mubuf_alias.s',
|
||||
'vbuffer_mtbuf': 'gfx12_asm_vbuffer_mtbuf.s', 'vbuffer_mtbuf_alias': 'gfx12_asm_vbuffer_mtbuf_alias.s',
|
||||
'vimage': 'gfx12_asm_vimage.s', 'vimage_alias': 'gfx12_asm_vimage_alias.s', 'vsample': 'gfx12_asm_vsample.s',
|
||||
@@ -157,7 +158,7 @@ class TestLLVMRDNA3(TestLLVMBase):
|
||||
'sop1': SOP1, 'sop2': SOP2, 'sopc': SOPC, 'sopk': SOPK, 'sopp': SOPP,
|
||||
'vop1': VOP1, 'vop2': VOP2, 'vopc': VOPC, 'vopcx': VOPC, 'vop3': VOP3, 'vop3p': VOP3P,
|
||||
'vinterp': VINTERP, 'vopd': VOPD, 'ds': DS, 'smem': SMEM, 'flat': FLAT,
|
||||
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'wmma': VOP3P, 'ldsdir': LDSDIR, 'exp': EXP,
|
||||
'mubuf': MUBUF, 'mtbuf': MTBUF, 'mimg': MIMG, 'mimg_features': MIMG, 'wmma': VOP3P, 'ldsdir': LDSDIR, 'exp': EXP,
|
||||
'vop3_from_vop1': VOP3, 'vop3_from_vop2': VOP3, 'vop3_from_vopc': VOP3, 'vop3_from_vopcx': VOP3,
|
||||
'vop3_features': VOP3, 'vop3p_features': VOP3P, 'vopd_features': VOPD,
|
||||
'vop3_alias': VOP3, 'vop3p_alias': VOP3P, 'vopc_alias': VOPC, 'vopcx_alias': VOPC,
|
||||
@@ -186,7 +187,8 @@ class TestLLVMRDNA4(TestLLVMBase):
|
||||
'vbuffer_mtbuf': get('VBUFFER'), 'vbuffer_mtbuf_alias': get('VBUFFER'),
|
||||
'vdsdir': get('VDSDIR'), 'vdsdir_alias': get('VDSDIR'),
|
||||
'vflat': get('VFLAT'), 'vflat_alias': get('VFLAT'),
|
||||
'vscratch': get('VSCRATCH'), 'vscratch_alias': get('VSCRATCH'),
|
||||
'vglobal': get('VGLOBAL'), 'vglobal_alias': get('VGLOBAL'),
|
||||
'vscratch': get('VSCRATCH'),
|
||||
'vimage': get('VIMAGE'), 'vimage_alias': get('VIMAGE'), 'vsample': get('VSAMPLE'),
|
||||
'wmma_w32': get('VOP3P'), 'wmma_w64': get('VOP3P'),
|
||||
'global_load_tr': get('VGLOBAL'),
|
||||
|
||||
Reference in New Issue
Block a user