mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
progress
This commit is contained in:
@@ -49,6 +49,7 @@ class InstOp(IntEnum):
|
||||
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
|
||||
VALU_MAD64 = 0xe # 64-bit multiply-add
|
||||
VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers
|
||||
VINTERP = 0x12 # interpolation: v_interp_p10_f32, v_interp_p2_f32
|
||||
|
||||
# FLAT memory ops on traced SIMD (0x1x range)
|
||||
FLAT_LOAD = 0x1c
|
||||
|
||||
@@ -90,6 +90,12 @@ from extra.assembly.amd.autogen.rdna3.ins import (
|
||||
v_dot2_f16_f16,
|
||||
# WMMA
|
||||
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
|
||||
# Permlane ops
|
||||
v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32,
|
||||
# Interpolation
|
||||
v_interp_p10_f32, v_interp_p2_f32,
|
||||
# Barrier
|
||||
s_barrier,
|
||||
# SrcEnum for NULL soffset
|
||||
SrcEnum,
|
||||
)
|
||||
@@ -97,7 +103,7 @@ from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
|
||||
|
||||
from extra.assembly.amd.test.test_sqtt_hw import (
|
||||
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS
|
||||
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst
|
||||
)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -256,6 +262,20 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
|
||||
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
|
||||
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
|
||||
|
||||
# Permlane operations - cross-lane data movement
|
||||
# NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs)
|
||||
# NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp)
|
||||
"VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]),
|
||||
"VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]),
|
||||
|
||||
# Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP)
|
||||
"VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]),
|
||||
"VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]),
|
||||
|
||||
# Barrier - wave synchronization
|
||||
# NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op)
|
||||
"SALU_barrier": ("s_barrier", [s_barrier()]),
|
||||
|
||||
# LDS atomics
|
||||
"LDS_atomic_add": ("ds_add_u32", [
|
||||
v_mov_b32_e32(v[0], 0), # LDS address
|
||||
@@ -662,46 +682,53 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
|
||||
}
|
||||
|
||||
|
||||
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]:
  """Run instructions multiple times to collect InstOp variants and VALUINST counts.

  Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
  - 0x2x range: wave ran on traced SIMD (matched)
  - 0x5x range: wave ran on other SIMD (not matched)

  Args:
    instructions: instruction list passed unchanged to run_asm_sqtt on every attempt.
    max_attempts: how many times the program is run and its trace decoded.

  Returns:
    A 4-tuple (all_runs, all_packets, all_ops, max_valuinst):
    - all_runs: one (traced_simd, blobs) tuple per attempt; traced_simd is always 0 here.
    - all_packets: the decoded packets of each attempt, in run order.
    - all_ops: union of the InstOp values seen on SIMDs 0-3 across all attempts.
    - max_valuinst: the largest per-SIMD VALUINST packet count seen in any attempt.
  """
  all_ops = set()
  all_runs: list[tuple[int, list[bytes]]] = []
  all_packets = []
  max_valuinst = 0
  SQTT_SIMD_SEL.value = 0  # only trace SIMD 0
  for _ in range(max_attempts):
    blobs = run_asm_sqtt(instructions)
    packets = decode_all_blobs(blobs)
    # get ops and valuinst from all SIMDs
    ops = set()
    valuinst_count = 0
    for simd in [0, 1, 2, 3]:
      ops.update(get_inst_ops(packets, traced_simd=simd))
      # keep the per-SIMD maximum rather than a sum across SIMDs
      valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd))
    all_runs.append((0, blobs))
    all_packets.append(packets)
    all_ops.update(ops)
    max_valuinst = max(max_valuinst, valuinst_count)
  return all_runs, all_packets, all_ops, max_valuinst
|
||||
|
||||
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
|
||||
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]:
|
||||
"""Run all instruction tests and collect InstOp values."""
|
||||
discovered: dict[int, set[str]] = {}
|
||||
failures: dict[str, Exception] = {}
|
||||
valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets
|
||||
|
||||
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
|
||||
try:
|
||||
all_runs, _, ops = run_with_retry(instructions)
|
||||
all_runs, _, ops, valuinst_count = run_with_retry(instructions)
|
||||
|
||||
for op in ops:
|
||||
if op not in discovered:
|
||||
discovered[op] = set()
|
||||
discovered[op].add(f"{test_name}")
|
||||
|
||||
if valuinst_count > 0:
|
||||
valuinst_tests[test_name] = valuinst_count
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(f"\n{'─'*60}")
|
||||
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
|
||||
@@ -741,19 +768,20 @@ def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
|
||||
print(f"\n=== traced simd={traced_simd} ===")
|
||||
print_blobs(blobs, wave_only=False)
|
||||
if DEBUG >= 1:
|
||||
status = colored("✓", "green") if ops else colored("∅", "yellow")
|
||||
status = colored("✓", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("∅", "yellow"))
|
||||
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
|
||||
print(f" {status} {test_name:25s} ops=[{ops_str}]")
|
||||
valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else ""
|
||||
print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}")
|
||||
|
||||
except Exception as e:
|
||||
failures[test_name] = e
|
||||
if DEBUG >= 1:
|
||||
print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")
|
||||
|
||||
return discovered, failures
|
||||
return discovered, failures, valuinst_tests
|
||||
|
||||
|
||||
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
|
||||
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None:
|
||||
"""Print discovery summary."""
|
||||
known_ops = {e.value for e in InstOp}
|
||||
discovered_ops = set(discovered.keys())
|
||||
@@ -773,6 +801,15 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
|
||||
sources = ", ".join(sorted(discovered[op]))
|
||||
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
|
||||
|
||||
# VALUINST tests (instructions that only produce VALUINST, not INST packets)
|
||||
valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())}
|
||||
if valuinst_only:
|
||||
print("\n" + "=" * 60)
|
||||
print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan"))
|
||||
print("=" * 60)
|
||||
for test_name, count in sorted(valuinst_only.items()):
|
||||
print(f" {test_name}: {count} VALUINST packets")
|
||||
|
||||
# Missing from enum
|
||||
missing = known_ops - discovered_ops
|
||||
if missing:
|
||||
@@ -805,6 +842,7 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
|
||||
if known_ops:
|
||||
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
|
||||
print(f" New ops found: {len(new_ops)}")
|
||||
print(f" VALUINST-only: {len(valuinst_only)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -813,5 +851,5 @@ if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
|
||||
|
||||
discovered, failures = discover_all_instops()
|
||||
print_summary(discovered, failures)
|
||||
discovered, failures, valuinst_tests = discover_all_instops()
|
||||
print_summary(discovered, failures, valuinst_tests)
|
||||
|
||||
@@ -205,6 +205,19 @@ def get_inst_ops(packets: list, traced_simd: int | None = None) -> set:
|
||||
in_wave = False
|
||||
return ops
|
||||
|
||||
def count_valuinst(packets: list, traced_simd: int | None = None) -> int:
|
||||
"""Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD."""
|
||||
count = 0
|
||||
in_wave = False
|
||||
for p in packets:
|
||||
if isinstance(p, WAVESTART):
|
||||
in_wave = traced_simd is None or p.simd == traced_simd
|
||||
if in_wave and isinstance(p, VALUINST):
|
||||
count += 1
|
||||
if isinstance(p, WAVEEND):
|
||||
in_wave = False
|
||||
return count
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TESTS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Reference in New Issue
Block a user