mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
progress
This commit is contained in:
@@ -49,6 +49,7 @@ class InstOp(IntEnum):
|
|||||||
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
|
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
|
||||||
VALU_MAD64 = 0xe # 64-bit multiply-add
|
VALU_MAD64 = 0xe # 64-bit multiply-add
|
||||||
VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers
|
VALU_64 = 0xf # 64-bit: add, mul, fma, rcp, sqrt, rounding, frexp, div helpers
|
||||||
|
VINTERP = 0x12 # interpolation: v_interp_p10_f32, v_interp_p2_f32
|
||||||
|
|
||||||
# FLAT memory ops on traced SIMD (0x1x range)
|
# FLAT memory ops on traced SIMD (0x1x range)
|
||||||
FLAT_LOAD = 0x1c
|
FLAT_LOAD = 0x1c
|
||||||
|
|||||||
@@ -90,6 +90,12 @@ from extra.assembly.amd.autogen.rdna3.ins import (
|
|||||||
v_dot2_f16_f16,
|
v_dot2_f16_f16,
|
||||||
# WMMA
|
# WMMA
|
||||||
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
|
v_wmma_f32_16x16x16_f16, v_wmma_f16_16x16x16_f16, v_wmma_i32_16x16x16_iu8,
|
||||||
|
# Permlane ops
|
||||||
|
v_permlane64_b32_e32, v_permlane16_b32, v_permlanex16_b32,
|
||||||
|
# Interpolation
|
||||||
|
v_interp_p10_f32, v_interp_p2_f32,
|
||||||
|
# Barrier
|
||||||
|
s_barrier,
|
||||||
# SrcEnum for NULL soffset
|
# SrcEnum for NULL soffset
|
||||||
SrcEnum,
|
SrcEnum,
|
||||||
)
|
)
|
||||||
@@ -97,7 +103,7 @@ from extra.assembly.amd.dsl import v, s
|
|||||||
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
|
from extra.assembly.amd.sqtt import InstOp, INST, WAVESTART, WAVEEND, ALUEXEC, VMEMEXEC
|
||||||
|
|
||||||
from extra.assembly.amd.test.test_sqtt_hw import (
|
from extra.assembly.amd.test.test_sqtt_hw import (
|
||||||
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS
|
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs, get_wave_packets, format_packet, PACKET_COLORS, count_valuinst
|
||||||
)
|
)
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
@@ -256,6 +262,20 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
|
|||||||
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
|
"VALU_wmma_f16_f16": ("v_wmma_f16_16x16x16_f16", [v_wmma_f16_16x16x16_f16(v[0:7], v[8:15], v[16:23], v[0:7])]),
|
||||||
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
|
"VALU_wmma_i32_iu8": ("v_wmma_i32_16x16x16_iu8", [v_wmma_i32_16x16x16_iu8(v[0:7], v[8:11], v[12:15], v[0:7])]),
|
||||||
|
|
||||||
|
# Permlane operations - cross-lane data movement
|
||||||
|
# NOTE: permlane64 produces NO SQTT packets in wave32 mode (it's for wave64 pairs)
|
||||||
|
# NOTE: permlane16/x16 produce VALUINST packets (no specific InstOp)
|
||||||
|
"VALU_permlane16": ("v_permlane16_b32", [v_permlane16_b32(v[0], v[1], s[2], s[3])]),
|
||||||
|
"VALU_permlanex16": ("v_permlanex16_b32", [v_permlanex16_b32(v[0], v[1], s[2], s[3])]),
|
||||||
|
|
||||||
|
# Interpolation - used in graphics shaders (produces InstOp 0x12 VINTERP)
|
||||||
|
"VINTERP_p10": ("v_interp_p10_f32", [v_interp_p10_f32(v[0], v[1], v[2], v[3])]),
|
||||||
|
"VINTERP_p2": ("v_interp_p2_f32", [v_interp_p2_f32(v[0], v[1], v[2], v[3])]),
|
||||||
|
|
||||||
|
# Barrier - wave synchronization
|
||||||
|
# NOTE: s_barrier produces NO SQTT instruction packets (with 1 wave, it's essentially a no-op)
|
||||||
|
"SALU_barrier": ("s_barrier", [s_barrier()]),
|
||||||
|
|
||||||
# LDS atomics
|
# LDS atomics
|
||||||
"LDS_atomic_add": ("ds_add_u32", [
|
"LDS_atomic_add": ("ds_add_u32", [
|
||||||
v_mov_b32_e32(v[0], 0), # LDS address
|
v_mov_b32_e32(v[0], 0), # LDS address
|
||||||
@@ -662,46 +682,53 @@ INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set]:
|
def run_with_retry(instructions: list, max_attempts: int = 20) -> tuple[list[tuple[int, list[bytes]]], list[list], set, int]:
|
||||||
"""Run instructions multiple times to collect InstOp variants.
|
"""Run instructions multiple times to collect InstOp variants.
|
||||||
|
|
||||||
Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
|
Memory ops produce different InstOp values (0x2x vs 0x5x) depending on which SIMD executes them:
|
||||||
- 0x2x range: wave ran on traced SIMD (matched)
|
- 0x2x range: wave ran on traced SIMD (matched)
|
||||||
- 0x5x range: wave ran on other SIMD (not matched)
|
- 0x5x range: wave ran on other SIMD (not matched)
|
||||||
|
|
||||||
Returns list of (traced_simd, blobs) tuples.
|
Returns list of (traced_simd, blobs) tuples, all_packets, all_ops, max_valuinst_count.
|
||||||
"""
|
"""
|
||||||
all_ops = set()
|
all_ops = set()
|
||||||
all_runs: list[tuple[int, list[bytes]]] = []
|
all_runs: list[tuple[int, list[bytes]]] = []
|
||||||
all_packets = []
|
all_packets = []
|
||||||
|
max_valuinst = 0
|
||||||
SQTT_SIMD_SEL.value = 0 # only trace SIMD 0
|
SQTT_SIMD_SEL.value = 0 # only trace SIMD 0
|
||||||
for _ in range(max_attempts):
|
for _ in range(max_attempts):
|
||||||
blobs = run_asm_sqtt(instructions)
|
blobs = run_asm_sqtt(instructions)
|
||||||
packets = decode_all_blobs(blobs)
|
packets = decode_all_blobs(blobs)
|
||||||
# get ops from waves on traced SIMD 0 (gives 0x2x range)
|
# get ops and valuinst from all SIMDs
|
||||||
ops = get_inst_ops(packets, traced_simd=0)
|
ops = set()
|
||||||
# also get ops from waves on other SIMDs (gives 0x5x range for memory ops)
|
valuinst_count = 0
|
||||||
for simd in [1, 2, 3]:
|
for simd in [0, 1, 2, 3]:
|
||||||
ops.update(get_inst_ops(packets, traced_simd=simd))
|
ops.update(get_inst_ops(packets, traced_simd=simd))
|
||||||
|
valuinst_count = max(valuinst_count, count_valuinst(packets, traced_simd=simd))
|
||||||
all_runs.append((0, blobs))
|
all_runs.append((0, blobs))
|
||||||
all_packets.append(packets)
|
all_packets.append(packets)
|
||||||
all_ops.update(ops)
|
all_ops.update(ops)
|
||||||
return all_runs, all_packets, all_ops
|
max_valuinst = max(max_valuinst, valuinst_count)
|
||||||
|
return all_runs, all_packets, all_ops, max_valuinst
|
||||||
|
|
||||||
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
|
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception], dict[str, int]]:
|
||||||
"""Run all instruction tests and collect InstOp values."""
|
"""Run all instruction tests and collect InstOp values."""
|
||||||
discovered: dict[int, set[str]] = {}
|
discovered: dict[int, set[str]] = {}
|
||||||
failures: dict[str, Exception] = {}
|
failures: dict[str, Exception] = {}
|
||||||
|
valuinst_tests: dict[str, int] = {} # tests that produced VALUINST packets
|
||||||
|
|
||||||
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
|
for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
|
||||||
try:
|
try:
|
||||||
all_runs, _, ops = run_with_retry(instructions)
|
all_runs, _, ops, valuinst_count = run_with_retry(instructions)
|
||||||
|
|
||||||
for op in ops:
|
for op in ops:
|
||||||
if op not in discovered:
|
if op not in discovered:
|
||||||
discovered[op] = set()
|
discovered[op] = set()
|
||||||
discovered[op].add(f"{test_name}")
|
discovered[op].add(f"{test_name}")
|
||||||
|
|
||||||
|
if valuinst_count > 0:
|
||||||
|
valuinst_tests[test_name] = valuinst_count
|
||||||
|
|
||||||
if DEBUG >= 2:
|
if DEBUG >= 2:
|
||||||
print(f"\n{'─'*60}")
|
print(f"\n{'─'*60}")
|
||||||
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
|
print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]}")
|
||||||
@@ -741,19 +768,20 @@ def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
|
|||||||
print(f"\n=== traced simd={traced_simd} ===")
|
print(f"\n=== traced simd={traced_simd} ===")
|
||||||
print_blobs(blobs, wave_only=False)
|
print_blobs(blobs, wave_only=False)
|
||||||
if DEBUG >= 1:
|
if DEBUG >= 1:
|
||||||
status = colored("✓", "green") if ops else colored("∅", "yellow")
|
status = colored("✓", "green") if ops else (colored("V", "cyan") if valuinst_count > 0 else colored("∅", "yellow"))
|
||||||
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
|
ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
|
||||||
print(f" {status} {test_name:25s} ops=[{ops_str}]")
|
valuinst_str = f" valuinst={valuinst_count}" if valuinst_count > 0 and not ops else ""
|
||||||
|
print(f" {status} {test_name:25s} ops=[{ops_str}]{valuinst_str}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failures[test_name] = e
|
failures[test_name] = e
|
||||||
if DEBUG >= 1:
|
if DEBUG >= 1:
|
||||||
print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")
|
print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")
|
||||||
|
|
||||||
return discovered, failures
|
return discovered, failures, valuinst_tests
|
||||||
|
|
||||||
|
|
||||||
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
|
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception], valuinst_tests: dict[str, int]) -> None:
|
||||||
"""Print discovery summary."""
|
"""Print discovery summary."""
|
||||||
known_ops = {e.value for e in InstOp}
|
known_ops = {e.value for e in InstOp}
|
||||||
discovered_ops = set(discovered.keys())
|
discovered_ops = set(discovered.keys())
|
||||||
@@ -773,6 +801,15 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
|
|||||||
sources = ", ".join(sorted(discovered[op]))
|
sources = ", ".join(sorted(discovered[op]))
|
||||||
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
|
print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")
|
||||||
|
|
||||||
|
# VALUINST tests (instructions that only produce VALUINST, not INST packets)
|
||||||
|
valuinst_only = {k: v for k, v in valuinst_tests.items() if not any(k in tests for tests in discovered.values())}
|
||||||
|
if valuinst_only:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(colored("VALUINST-ONLY INSTRUCTIONS (no InstOp, use VALUINST packet)", "cyan"))
|
||||||
|
print("=" * 60)
|
||||||
|
for test_name, count in sorted(valuinst_only.items()):
|
||||||
|
print(f" {test_name}: {count} VALUINST packets")
|
||||||
|
|
||||||
# Missing from enum
|
# Missing from enum
|
||||||
missing = known_ops - discovered_ops
|
missing = known_ops - discovered_ops
|
||||||
if missing:
|
if missing:
|
||||||
@@ -805,6 +842,7 @@ def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception
|
|||||||
if known_ops:
|
if known_ops:
|
||||||
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
|
print(f" Coverage: {len(discovered_ops & known_ops)}/{len(known_ops)} ({100*len(discovered_ops & known_ops)//len(known_ops)}%)")
|
||||||
print(f" New ops found: {len(new_ops)}")
|
print(f" New ops found: {len(new_ops)}")
|
||||||
|
print(f" VALUINST-only: {len(valuinst_only)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -813,5 +851,5 @@ if __name__ == "__main__":
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
|
print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
|
||||||
|
|
||||||
discovered, failures = discover_all_instops()
|
discovered, failures, valuinst_tests = discover_all_instops()
|
||||||
print_summary(discovered, failures)
|
print_summary(discovered, failures, valuinst_tests)
|
||||||
|
|||||||
@@ -205,6 +205,19 @@ def get_inst_ops(packets: list, traced_simd: int | None = None) -> set:
|
|||||||
in_wave = False
|
in_wave = False
|
||||||
return ops
|
return ops
|
||||||
|
|
||||||
|
def count_valuinst(packets: list, traced_simd: int | None = None) -> int:
|
||||||
|
"""Count VALUINST packets within WAVESTART..WAVEEND on traced SIMD."""
|
||||||
|
count = 0
|
||||||
|
in_wave = False
|
||||||
|
for p in packets:
|
||||||
|
if isinstance(p, WAVESTART):
|
||||||
|
in_wave = traced_simd is None or p.simd == traced_simd
|
||||||
|
if in_wave and isinstance(p, VALUINST):
|
||||||
|
count += 1
|
||||||
|
if isinstance(p, WAVEEND):
|
||||||
|
in_wave = False
|
||||||
|
return count
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
# TESTS
|
# TESTS
|
||||||
# ═══════════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|||||||
Reference in New Issue
Block a user