This commit is contained in:
George Hotz
2026-01-01 21:11:29 -05:00
parent de29a49ea3
commit 267bbb163e
3 changed files with 68 additions and 16 deletions

View File

@@ -49,6 +49,7 @@ class InstOp(IntEnum):
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
VALU_MAD64 = 0xe # 64-bit multiply-add
VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers
VINTERP = 0x12 # interpolation: v_interp_p10_f32, v_interp_p2_f32
# FLAT memory ops on traced SIMD (0x1x range)
FLAT_LOAD = 0x1c

View File

@@ -90,6 +90,12 @@ from extra.assembly.amd.autogen.rdna3.ins import (
v_dot2_f16_f16,
# WMMA
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
# Permlane ops
v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32,
# Interpolation
v_interp_p10_f32, v_interp_p2_f32,
# Barrier
s_barrier,
# SrcEnum for NULL soffset
SrcEnum,
)
@@ -97,7 +103,7 @@ from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
from extra.assembly.amd.test.test_sqtt_hw import (
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst
)
# ═══════════════════════════════════════════════════════════════════════════════
@@ -256,6 +262,20 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
# Permlane operations - cross-lane data movement
# NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs)
# NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp)
"VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]),
"VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]),
# Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP)
"VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]),
"VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]),
# Barrier - wave synchronization
# NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op)
"SALU_barrier": ("s_barrier", [s_barrier()]),
# LDS atomics
"LDS_atomic_add": ("ds_add_u32", [
v_mov_b32_e32(v[0], 0), # LDS address
@@ -662,46 +682,53 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
}
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]:
  """Run instructions multiple times to collect InstOp variants.

  Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
  - 0x2x range: wave ran on traced SIMD (matched)
  - 0x5x range: wave ran on other SIMD (not matched)

  Returns a 4-tuple (all_runs, all_packets, all_ops, max_valuinst):
  - all_runs: one (traced_simd, blobs) tuple per attempt; traced_simd is always recorded as 0
  - all_packets: the decoded packets of each attempt, in attempt order
  - all_ops: union of the InstOp values reported for SIMDs 0-3 over all attempts
  - max_valuinst: largest count_valuinst() result seen for any single SIMD in any single attempt
  """
  all_ops = set()
  all_runs: list[tuple[int, list[bytes]]] = []
  all_packets = []
  max_valuinst = 0
  SQTT_SIMD_SEL.value = 0  # only trace SIMD 0
  for _ in range(max_attempts):
    blobs = run_asm_sqtt(instructions)
    packets = decode_all_blobs(blobs)
    # get ops and valuinst from all SIMDs
    ops = set()
    valuinst_count = 0
    for simd in [0, 1, 2, 3]:
      ops.update(get_inst_ops(packets, traced_simd=simd))
      valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd))
    all_runs.append((0, blobs))
    all_packets.append(packets)
    all_ops.update(ops)
    max_valuinst = max(max_valuinst, valuinst_count)
  return all_runs, all_packets, all_ops, max_valuinst
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]:
"""Run all instruction tests and collect InstOp values."""
discovered: dict[int, set[str]] = {}
failures: dict[str, Exception] = {}
valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
try:
all_runs, _, ops = run_with_retry(instructions)
all_runs, _, ops, valuinst_count = run_with_retry(instructions)
for op in ops:
if op not in discovered:
discovered[op] = set()
discovered[op].add(f"{test_name}")
if valuinst_count > 0:
valuinst_tests[test_name] = valuinst_count
if DEBUG >= 2:
print(f"\n{''*60}")
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
@@ -741,19 +768,20 @@ def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
print(f"\n=== traced simd={traced_simd} ===")
print_blobs(blobs, wave_only=False)
if DEBUG >= 1:
status = colored("", "green") if ops else colored("", "yellow")
status = colored("", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("", "yellow"))
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
print(f" {status} {test_name:25s} ops=[{ops_str}]")
valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else ""
print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}")
except Exception as e:
failures[test_name] = e
if DEBUG >= 1:
print(f" {colored('', 'red')} {test_name:25s} FAILED: {e}")
return discovered, failures
return discovered, failures, valuinst_tests
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None:
"""Print discovery summary."""
known_ops = {e.value for e in InstOp}
discovered_ops = set(discovered.keys())
@@ -773,6 +801,15 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
sources = ", ".join(sorted(discovered[op]))
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
# VALUINST tests (instructions that only produce VALUINST, not INST packets)
valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())}
if valuinst_only:
print("\n" + "=" * 60)
print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan"))
print("=" * 60)
for test_name, count in sorted(valuinst_only.items()):
print(f" {test_name}: {count} VALUINST packets")
# Missing from enum
missing = known_ops - discovered_ops
if missing:
@@ -805,6 +842,7 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
if known_ops:
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
print(f" New ops found: {len(new_ops)}")
print(f" VALUINST-only: {len(valuinst_only)}")
if __name__ == "__main__":
@@ -813,5 +851,5 @@ if __name__ == "__main__":
print("=" * 60)
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
discovered, failures = discover_all_instops()
print_summary(discovered, failures)
discovered, failures, valuinst_tests = discover_all_instops()
print_summary(discovered, failures, valuinst_tests)

View File

@@ -205,6 +205,19 @@ def get_inst_ops(packets: list, traced_simd: int | None = None) -> set:
in_wave = False
return ops
def count_valuinst(packets: list, traced_simd: int | None = None) -> int:
"""Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD."""
count = 0
in_wave = False
for p in packets:
if isinstance(p, WAVESTART):
in_wave = traced_simd is None or p.simd == traced_simd
if in_wave and isinstance(p, VALUINST):
count += 1
if isinstance(p, WAVEEND):
in_wave = False
return count
# ═══════════════════════════════════════════════════════════════════════════════
# TESTS
# ═══════════════════════════════════════════════════════════════════════════════