This commit is contained in:
George Hotz
2026-01-05 09:42:26 -08:00
parent a6c17e7081
commit 4e213cee95
2 changed files with 24 additions and 14 deletions

View File

@@ -33,7 +33,7 @@ def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
"""Detect instruction format from machine code bytes."""
assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
word = int.from_bytes(data[:4], 'little')
if arch == "cdna":
if arch in ("cdna", "gfx90a", "gfx942"):
if (word >> 30) == 0b11:
for cls in _CDNA_FORMATS_64:
if _matches_encoding(word, cls):
@@ -675,17 +675,17 @@ _RDNA3_ONLY_ALIASES = {'v_mul_legacy_f32', 'v_fmac_legacy_f32', 'v_fma_legacy_f3
's_andn2_saveexec_b32', 's_andn2_saveexec_b64', 's_andn2_wrexec_b32', 's_andn2_wrexec_b64',
's_orn1_saveexec_b32', 's_orn1_saveexec_b64', 's_orn2_saveexec_b32', 's_orn2_saveexec_b64',
# VOP1: CDNA uses old names
'v_cvt_flr_i32_f32', 'v_cvt_rpi_i32_f32', 'v_ffbh_i32', 'v_ffbh_u32', 'v_ffbl_b32'}
'v_cvt_flr_i32_f32', 'v_cvt_rpi_i32_f32', 'v_ffbh_i32', 'v_ffbh_u32', 'v_ffbl_b32',
# VOPC: CDNA uses tru suffix for float comparisons
'v_cmp_tru_f16', 'v_cmp_tru_f32', 'v_cmp_tru_f64', 'v_cmpx_tru_f16', 'v_cmpx_tru_f32', 'v_cmpx_tru_f64'}
# CDNA-specific aliases (GFX9 uses different names for some instructions)
# CDNA-specific aliases - CDNA uses dword naming, not b32
_CDNA_ALIASES = {
# VOP aliases: madmk/madak -> fmamk/fmaak (same encoding, different name in CDNA enum)
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_madmk_f32': 'v_fmamk_f32', 'v_madak_f32': 'v_fmaak_f32',
# VOPC: v_cmp_t_fXX -> v_cmp_tru_fXX for CDNA
# VOPC: v_cmp_t_fXX -> v_cmp_tru_fXX for CDNA float comparisons
'v_cmp_t_f16': 'v_cmp_tru_f16', 'v_cmp_t_f32': 'v_cmp_tru_f32', 'v_cmp_t_f64': 'v_cmp_tru_f64',
'v_cmpx_t_f16': 'v_cmpx_tru_f16', 'v_cmpx_t_f32': 'v_cmpx_tru_f32', 'v_cmpx_t_f64': 'v_cmpx_tru_f64',
# VOP1: flr/rpi -> floor/nearest for CDNA
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
}
def _apply_alias(text: str, arch: str = "rdna3") -> str:
@@ -1420,11 +1420,12 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
# CDNA VOP3A opcodes that need +64 offset (autogen has wrong values for GFX90a)
_CDNA_VOP3A_OPCODE_FIX = {'v_mul_legacy_f32': 64} # actual opcode should be 0x2a1, autogen has 0x261
# CDNA VOP3A opcodes that need +64 offset on gfx90a/gfx942 (autogen has wrong values)
_CDNA_VOP3A_OPCODE_FIX = {'v_mul_legacy_f32': 64} # gfx90a opcode is 0x2a1, autogen has 0x261
def _fix_cdna_opcode(inst, mnemonic: str):
"""Fix opcode for CDNA instructions where autogen has wrong values."""
def _fix_cdna_opcode(inst, mnemonic: str, is_gfx90a_or_942: bool):
"""Fix opcode for CDNA instructions where autogen has wrong values (gfx90a/gfx942 only)."""
if not is_gfx90a_or_942: return inst
base = mnemonic.removesuffix('_e64').removesuffix('_e32')
if base in _CDNA_VOP3A_OPCODE_FIX and hasattr(inst, '_values') and 'op' in inst._values:
op = inst._values['op']
@@ -1436,7 +1437,8 @@ def asm(text: str, arch: str = "rdna3") -> Inst:
# Normalize arch: gfx90a and gfx942 are CDNA variants
is_gfx942 = arch == "gfx942"
is_gfx90a = arch == "gfx90a"
if is_gfx942 or is_gfx90a: arch = "cdna"
is_gfx90a_or_942 = is_gfx942 or is_gfx90a
if is_gfx90a_or_942: arch = "cdna"
mnemonic = text.split()[0].lower()
dsl = get_dsl(text, arch, gfx942=is_gfx942)
if arch == "cdna":
@@ -1459,8 +1461,16 @@ def asm(text: str, arch: str = "rdna3") -> Inst:
ns = {n: getattr(ins, n) for n in dir(ins) if not n.startswith('_')}
ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF})
fix = (lambda inst: _fix_cdna_opcode(inst, mnemonic)) if arch == "cdna" else (lambda inst: inst)
fix = (lambda inst: _fix_cdna_opcode(inst, mnemonic, is_gfx90a_or_942)) if arch == "cdna" else (lambda inst: inst)
try:
# Generic CDNA (not gfx90a/gfx942): v_mul_legacy_f32 uses VOP2 opcode 4, _e64 uses VOP3A opcode 0x104
if arch == "cdna" and not is_gfx90a_or_942 and mnemonic.startswith('v_mul_legacy_f32'):
from extra.assembly.amd.autogen.cdna.ins import VOP2, VOP3A
args = _parse_ops(text[len(mnemonic):])
dsl_args = [_op2dsl(a, arch) for a in args]
if mnemonic == 'v_mul_legacy_f32_e64':
return eval(f"VOP3A(op=0x104, vdst={dsl_args[0]}, src0={dsl_args[1]}, src1={dsl_args[2]}, src2=RawImm(0))", ns)
return eval(f"VOP2(op=4, vdst={dsl_args[0]}, src0={dsl_args[1]}, vsrc1={dsl_args[2]})", ns)
# For CDNA, prefer _e32 variants for VOP1/VOP2 when available (bare names map to VOP3)
# But skip if:
# - already has _e64 suffix (explicit VOP3 request)

View File

@@ -49,7 +49,7 @@ def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]:
tests = _parse_llvm_tests(text, r'(?:GFX90A|GFX942)')
else:
tests = _parse_llvm_tests(text, r'(?:VI9|GFX9|CHECK)')
return [(a, d) for a, d in tests if not _is_mimg(d)] if arch == "cdna" else tests
return [(a, d) for a, d in tests if not _is_mimg(d)] if arch in ("cdna", "gfx90a", "gfx942") else tests
def _compile_asm_batch(instrs: list[str]) -> list[bytes]:
if not instrs: return []
@@ -100,8 +100,8 @@ for f in RDNA_FILES:
setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm"))
setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm"))
for f in CDNA_FILES:
# Use gfx942 arch for gfx942-specific files, cdna for others
arch = "gfx942" if "gfx942" in f else "cdna"
# Use specific arch for gfx90a/gfx942 files, generic cdna for others
arch = "gfx942" if "gfx942" in f else "gfx90a" if "gfx90a" in f else "cdna"
setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, arch, "roundtrip"))
setattr(TestLLVM, f"test_cdna_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, arch, "asm"))
setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, arch, "disasm"))