This commit is contained in:
George Hotz
2026-01-01 21:11:29 -05:00
parent de29a49ea3
commit 267bbb163e
3 changed files with 68 additions and 16 deletions

View File

@@ -49,6 +49,7 @@ class InstOp(IntEnum):
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
VALU_MAD64 = 0xe # 64-bit multiply-add VALU_MAD64 = 0xe # 64-bit multiply-add
VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers
VINTERP = 0x12 # interpolation: v_interp_p10_f32, v_interp_p2_f32
# FLAT memory ops on traced SIMD (0x1x range) # FLAT memory ops on traced SIMD (0x1x range)
FLAT_LOAD = 0x1c FLAT_LOAD = 0x1c

View File

@@ -90,6 +90,12 @@ from extra.assembly.amd.autogen.rdna3.ins import (
v_dot2_f16_f16, v_dot2_f16_f16,
# WMMA # WMMA
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8, v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
# Permlane ops
v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32,
# Interpolation
v_interp_p10_f32, v_interp_p2_f32,
# Barrier
s_barrier,
# SrcEnum for NULL soffset # SrcEnum for NULL soffset
SrcEnum, SrcEnum,
) )
@@ -97,7 +103,7 @@ from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
from extra.assembly.amd.test.test_sqtt_hw import ( from extra.assembly.amd.test.test_sqtt_hw import (
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst
) )
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
@@ -256,6 +262,20 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]), "VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]), "VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
# Permlane operations - cross-lane data movement
# NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs)
# NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp)
"VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]),
"VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]),
# Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP)
"VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]),
"VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]),
# Barrier - wave synchronization
# NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op)
"SALU_barrier": ("s_barrier", [s_barrier()]),
# LDS atomics # LDS atomics
"LDS_atomic_add": ("ds_add_u32", [ "LDS_atomic_add": ("ds_add_u32", [
v_mov_b32_e32(v[0], 0), # LDS address v_mov_b32_e32(v[0], 0), # LDS address
@@ -662,46 +682,53 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
} }
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set]: def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]:
"""Run instructions multiple times to collect InstOp variants. """Run instructions multiple times to collect InstOp variants.
Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them: Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
- 0x2x range: wave ran on traced SIMD (matched) - 0x2x range: wave ran on traced SIMD (matched)
- 0x5x range: wave ran on other SIMD (not matched) - 0x5x range: wave ran on other SIMD (not matched)
Returns list of (traced_simd, blobs) tuples. Returns list of (traced_simd, blobs) tuples, all_packets, all_ops, max_valuinst_count.
""" """
all_ops = set() all_ops = set()
all_runs: list[tuple[int, list[bytes]]] = [] all_runs: list[tuple[int, list[bytes]]] = []
all_packets = [] all_packets = []
max_valuinst = 0
SQTT_SIMD_SEL.value = 0 # only trace SIMD 0 SQTT_SIMD_SEL.value = 0 # only trace SIMD 0
for _ in range(max_attempts): for _ in range(max_attempts):
blobs = run_asm_sqtt(instructions) blobs = run_asm_sqtt(instructions)
packets = decode_all_blobs(blobs) packets = decode_all_blobs(blobs)
# get ops from waves on traced SIMD 0 (gives 0x2x range) # get ops and valuinst from all SIMDs
ops = get_inst_ops(packets, traced_simd=0) ops = set()
# also get ops from waves on other SIMDs (gives 0x5x range for memory ops) valuinst_count = 0
for simd in [1, 2, 3]: for simd in [0, 1, 2, 3]:
ops.update(get_inst_ops(packets, traced_simd=simd)) ops.update(get_inst_ops(packets, traced_simd=simd))
valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd))
all_runs.append((0, blobs)) all_runs.append((0, blobs))
all_packets.append(packets) all_packets.append(packets)
all_ops.update(ops) all_ops.update(ops)
return all_runs, all_packets, all_ops max_valuinst = max(max_valuinst, valuinst_count)
return all_runs, all_packets, all_ops, max_valuinst
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]: def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]:
"""Run all instruction tests and collect InstOp values.""" """Run all instruction tests and collect InstOp values."""
discovered: dict[int, set[str]] = {} discovered: dict[int, set[str]] = {}
failures: dict[str, Exception] = {} failures: dict[str, Exception] = {}
valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items(): for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
try: try:
all_runs, _, ops = run_with_retry(instructions) all_runs, _, ops, valuinst_count = run_with_retry(instructions)
for op in ops: for op in ops:
if op not in discovered: if op not in discovered:
discovered[op] = set() discovered[op] = set()
discovered[op].add(f"{test_name}") discovered[op].add(f"{test_name}")
if valuinst_count > 0:
valuinst_tests[test_name] = valuinst_count
if DEBUG >= 2: if DEBUG >= 2:
print(f"\n{''*60}") print(f"\n{''*60}")
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}") print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
@@ -741,19 +768,20 @@ def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
print(f"\n=== traced simd={traced_simd} ===") print(f"\n=== traced simd={traced_simd} ===")
print_blobs(blobs, wave_only=False) print_blobs(blobs, wave_only=False)
if DEBUG >= 1: if DEBUG >= 1:
status = colored("", "green") if ops else colored("", "yellow") status = colored("", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("", "yellow"))
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none" ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
print(f" {status} {test_name:25s} ops=[{ops_str}]") valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else ""
print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}")
except Exception as e: except Exception as e:
failures[test_name] = e failures[test_name] = e
if DEBUG >= 1: if DEBUG >= 1:
print(f" {colored('', 'red')} {test_name:25s} FAILED: {e}") print(f" {colored('', 'red')} {test_name:25s} FAILED: {e}")
return discovered, failures return discovered, failures, valuinst_tests
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None: def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None:
"""Print discovery summary.""" """Print discovery summary."""
known_ops = {e.value for e in InstOp} known_ops = {e.value for e in InstOp}
discovered_ops = set(discovered.keys()) discovered_ops = set(discovered.keys())
@@ -773,6 +801,15 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
sources = ", ".join(sorted(discovered[op])) sources = ", ".join(sorted(discovered[op]))
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}") print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
# VALUINST tests (instructions that only produce VALUINST, not INST packets)
valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())}
if valuinst_only:
print("\n" + "=" * 60)
print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan"))
print("=" * 60)
for test_name, count in sorted(valuinst_only.items()):
print(f" {test_name}: {count} VALUINST packets")
# Missing from enum # Missing from enum
missing = known_ops - discovered_ops missing = known_ops - discovered_ops
if missing: if missing:
@@ -805,6 +842,7 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
if known_ops: if known_ops:
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)") print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
print(f" New ops found: {len(new_ops)}") print(f" New ops found: {len(new_ops)}")
print(f" VALUINST-only: {len(valuinst_only)}")
if __name__ == "__main__": if __name__ == "__main__":
@@ -813,5 +851,5 @@ if __name__ == "__main__":
print("=" * 60) print("=" * 60)
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n") print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
discovered, failures = discover_all_instops() discovered, failures, valuinst_tests = discover_all_instops()
print_summary(discovered, failures) print_summary(discovered, failures, valuinst_tests)

View File

@@ -205,6 +205,19 @@ def get_inst_ops(packets: list, traced_simd: int | None = None) -> set:
in_wave = False in_wave = False
return ops return ops
def count_valuinst(packets: list, traced_simd: int | None = None) -> int:
"""Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD."""
count = 0
in_wave = False
for p in packets:
if isinstance(p, WAVESTART):
in_wave = traced_simd is None or p.simd == traced_simd
if in_wave and isinstance(p, VALUINST):
count += 1
if isinstance(p, WAVEEND):
in_wave = False
return count
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════
# TESTS # TESTS
# ═══════════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════════