From 267bbb163eedbf31ffbb9ccf1ef5c54da922cde9 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 1 Jan 2026 21:11:29 -0500 Subject: [PATCH] progress --- extra/assembly/amd/sqtt.py | 1 + extra/assembly/amd/test/discover_instops.py | 70 ++++++++++++++++----- extra/assembly/amd/test/test_sqtt_hw.py | 13 ++++ 3 files changed, 68 insertions(+), 16 deletions(-) diff --git a/extra/assembly/amd/sqtt.py b/extra/assembly/amd/sqtt.py index 2704efb417..08e8b95d4c 100644 --- a/extra/assembly/amd/sqtt.py +++ b/extra/assembly/amd/sqtt.py @@ -49,6 +49,7 @@ class InstOp(IntEnum): VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr VALU_MAD64 = 0xe # 64-bit multiply-add VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers + VINTERP = 0x12 # interpolation: v_interp_p10_f32, v_interp_p2_f32 # FLAT memory ops on traced SIMD (0x1x range) FLAT_LOAD = 0x1c diff --git a/extra/assembly/amd/test/discover_instops.py b/extra/assembly/amd/test/discover_instops.py index 263b76da39..fcb717e94a 100644 --- a/extra/assembly/amd/test/discover_instops.py +++ b/extra/assembly/amd/test/discover_instops.py @@ -90,6 +90,12 @@ from extra.assembly.amd.autogen.rdna3.ins import ( v_dot2_f16_f16, # WMMA v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8, + # Permlane ops + v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32, + # Interpolation + v_interp_p10_f32, v_interp_p2_f32, + # Barrier + s_barrier, # SrcEnum for NULL soffset SrcEnum, ) @@ -97,7 +103,7 @@ from extra.assembly.amd.dsl import v, s from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC from extra.assembly.amd.test.test_sqtt_hw import ( - run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS + run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst ) # ═══════════════════════════════════════════════════════════════════════════════ @@ -256,6 +262,20 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = { "VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]), "VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]), + # Permlane operations - cross-lane data movement + # NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs) + # NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp) + "VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]), + "VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]), + + # Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP) + "VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]), + "VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]), + + # Barrier - wave synchronization + # NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op) + "SALU_barrier": ("s_barrier", [s_barrier()]), + # LDS atomics "LDS_atomic_add": ("ds_add_u32", [ v_mov_b32_e32(v[0], 0), # LDS address @@ -662,46 +682,53 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = { } -def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set]: +def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]: """Run instructions multiple times to collect InstOp variants. Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them: - 0x2x range: wave ran on traced SIMD (matched) - 0x5x range: wave ran on other SIMD (not matched) - Returns list of (traced_simd, blobs) tuples. + Returns list of (traced_simd, blobs) tuples, all_packets, all_ops, max_valuinst_count. """ all_ops = set() all_runs: list[tuple[int, list[bytes]]] = [] all_packets = [] + max_valuinst = 0 SQTT_SIMD_SEL.value = 0 # only trace SIMD 0 for _ in range(max_attempts): blobs = run_asm_sqtt(instructions) packets = decode_all_blobs(blobs) - # get ops from waves on traced SIMD 0 (gives 0x2x range) - ops = get_inst_ops(packets, traced_simd=0) - # also get ops from waves on other SIMDs (gives 0x5x range for memory ops) - for simd in [1, 2, 3]: + # get ops and valuinst from all SIMDs + ops = set() + valuinst_count = 0 + for simd in [0, 1, 2, 3]: ops.update(get_inst_ops(packets, traced_simd=simd)) + valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd)) all_runs.append((0, blobs)) all_packets.append(packets) all_ops.update(ops) - return all_runs, all_packets, all_ops + max_valuinst = max(max_valuinst, valuinst_count) + return all_runs, all_packets, all_ops, max_valuinst -def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]: +def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]: """Run all instruction tests and collect InstOp values.""" discovered: dict[int, set[str]] = {} failures: dict[str, Exception] = {} + valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items(): try: - all_runs, _, ops = run_with_retry(instructions) + all_runs, _, ops, valuinst_count = run_with_retry(instructions) for op in ops: if op not in discovered: discovered[op] = set() discovered[op].add(f"{test_name}") + if valuinst_count > 0: + valuinst_tests[test_name] = valuinst_count + if DEBUG >= 2: print(f"\n{'─'*60}") print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}") @@ -741,19 +768,20 @@ def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]: print(f"\n=== traced simd={traced_simd} ===") print_blobs(blobs, wave_only=False) if DEBUG >= 1: - status = colored("✓", "green") if ops else colored("∅", "yellow") + status = colored("✓", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("∅", "yellow")) ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none" - print(f" {status} {test_name:25s} ops=[{ops_str}]") + valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else "" + print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}") except Exception as e: failures[test_name] = e if DEBUG >= 1: print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}") - return discovered, failures + return discovered, failures, valuinst_tests -def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None: +def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None: """Print discovery summary.""" known_ops = {e.value for e in InstOp} discovered_ops = set(discovered.keys()) @@ -773,6 +801,15 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception sources = ", ".join(sorted(discovered[op])) print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}") + # VALUINST tests (instructions that only produce VALUINST, not INST packets) + valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())} + if valuinst_only: + print("\n" + "=" * 60) + print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan")) + print("=" * 60) + for test_name, count in sorted(valuinst_only.items()): + print(f" {test_name}: {count} VALUINST packets") + # Missing from enum missing = known_ops - discovered_ops if missing: @@ -805,6 +842,7 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception if known_ops: print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)") print(f" New ops found: {len(new_ops)}") + print(f" VALUINST-only: {len(valuinst_only)}") if __name__ == "__main__": @@ -813,5 +851,5 @@ if __name__ == "__main__": print("=" * 60) print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n") - discovered, failures = discover_all_instops() - print_summary(discovered, failures) + discovered, failures, valuinst_tests = discover_all_instops() + print_summary(discovered, failures, valuinst_tests) diff --git a/extra/assembly/amd/test/test_sqtt_hw.py b/extra/assembly/amd/test/test_sqtt_hw.py index 718e708cce..728619b78a 100644 --- a/extra/assembly/amd/test/test_sqtt_hw.py +++ b/extra/assembly/amd/test/test_sqtt_hw.py @@ -205,6 +205,19 @@ def get_inst_ops(packets: list, traced_simd: int | None = None) -> set: in_wave = False return ops +def count_valuinst(packets: list, traced_simd: int | None = None) -> int: + """Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD.""" + count = 0 + in_wave = False + for p in packets: + if isinstance(p, WAVESTART): + in_wave = traced_simd is None or p.simd == traced_simd + if in_wave and isinstance(p, VALUINST): + count += 1 + if isinstance(p, WAVEEND): + in_wave = False + return count + # ═══════════════════════════════════════════════════════════════════════════════ # TESTS # ═══════════════════════════════════════════════════════════════════════════════