From dc977a03b0dfa79a475fe504976064a2d0f4271a Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 30 Jan 2026 00:12:39 +0300 Subject: [PATCH] nv_pma: bw decoder (#14424) * nv_pma: bw decoder * decoder fix * better --- extra/nv_pma/decode.py | 119 ++++++++++++++++++++++--------- extra/nv_pma/test/test_nvprof.py | 36 +++++++--- 2 files changed, 113 insertions(+), 42 deletions(-) diff --git a/extra/nv_pma/decode.py b/extra/nv_pma/decode.py index 4871292597..2bdb1dd242 100644 --- a/extra/nv_pma/decode.py +++ b/extra/nv_pma/decode.py @@ -25,7 +25,7 @@ class StallReason(enum.IntEnum): OTHER = 11 # misc, dispatch_stall SLEEPING = 12 # sleeping -STALL_KEY_MAP: dict[int, StallReason] = { +STALL_KEY_MAP_AMPERE: dict[int, StallReason] = { 1: StallReason.MEMORY_THROTTLE, 15: StallReason.MEMORY_THROTTLE, 2: StallReason.CONSTANT_MEMORY, 3: StallReason.SYNC, @@ -37,14 +37,25 @@ STALL_KEY_MAP: dict[int, StallReason] = { 18: StallReason.NONE, } +STALL_KEY_MAP_BLACKWELL: dict[int, StallReason] = { + 0x01: StallReason.MEMORY_THROTTLE, 0x0e: StallReason.MEMORY_THROTTLE, + 0x02: StallReason.SYNC, + 0x05: StallReason.INST_FETCH, 0x0a: StallReason.INST_FETCH, + 0x06: StallReason.EXEC_DEPENDENCY, 0x09: StallReason.EXEC_DEPENDENCY, + 0x08: StallReason.MEMORY_DEPENDENCY, + 0x0b: StallReason.PIPE_BUSY, 0x0f: StallReason.PIPE_BUSY, + 0x10: StallReason.OTHER, 0x13: StallReason.OTHER, + 0x11: StallReason.NONE, +} + +# Lookup table for extracting sample bytes from 32-byte packet (bytes 0-3, 8-31, skipping header at 4-7) +LOOKUP_28B = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + # ═══════════════════════════════════════════════════════════════════════════════ -# AMPERE PACKET DEFINITIONS (8-byte aligned) +# PACKET HEADER # ═══════════════════════════════════════════════════════════════════════════════ -# Lookup table for extracting sample bytes from 32-byte packet -LOOKUP_8B = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] - -class PMAHeaderAmpere8B(PacketType): +class PMAHeader(PacketType): num_bytes = bits[4:0] # number of sample bytes in this packet tpc_id_lo = bits[15:8] # TPC identifier low 8 bits tpc_id_hi = bits[27:25] # TPC identifier high 3 bits @@ -52,33 +63,57 @@ class PMAHeaderAmpere8B(PacketType): @property def tpc_id(self) -> int: return self.tpc_id_lo | (self.tpc_id_hi << 8) +# ═══════════════════════════════════════════════════════════════════════════════ +# 8-BYTE SAMPLE FORMAT (Ampere/Ada/Hopper) +# ═══════════════════════════════════════════════════════════════════════════════ + class PMASampleAmpere8B(PacketType): - pc_raw = bits[44:0] # raw PC value (actual PC = pc_raw << 4) - stall_lo = bits[47:45] # stall key low 3 bits - stall_hi = bits[49:48] # stall key high 2 bits - wave_id = bits[55:50] # warp/wave identifier (0-63) - active = bits[62:62] # active flag (warp was executing) + pc_raw = bits[44:0] # raw PC value (pc_offset = pc_raw << 4) + stall_key = bits[49:45] # stall reason key + wave_id = bits[55:50] # warp/wave identifier + active = bits[62:62] # 1 if warp was executing, 0 if scheduled but not issued @property def pc_offset(self) -> int: return self.pc_raw << 4 @property - def stall_key(self) -> int: return self.stall_lo | (self.stall_hi << 3) - @property - def stall_reason(self) -> StallReason: return STALL_KEY_MAP.get(self.stall_key, StallReason.OTHER) + def stall_reason(self) -> StallReason: return STALL_KEY_MAP_AMPERE.get(self.stall_key, StallReason.OTHER) + +# ═══════════════════════════════════════════════════════════════════════════════ +# 9-BYTE SAMPLE FORMAT (Blackwell+) +# ═══════════════════════════════════════════════════════════════════════════════ + +class PMASampleBlackwell9B(PacketType): + stall_key = bits[5:0] # stall reason key + pc_raw = bits[60:8] # raw PC value (pc_offset = pc_raw << 4) + wave_hi = bits[7:6] # wave_id high 2 bits + wave_lo = bits[71:68] # wave_id low 4 bits + active = bits[67:67] # 1 if warp was executing, 0 if scheduled but not issued + @property + def pc_offset(self) -> int: return self.pc_raw << 4 + @property + def stall_reason(self) -> StallReason: return STALL_KEY_MAP_BLACKWELL.get(self.stall_key, StallReason.OTHER) + @property + def wave_id(self) -> int: return (self.wave_hi << 4) | self.wave_lo + +PMASample = PMASampleAmpere8B|PMASampleBlackwell9B + +def decode(data: bytes, sm_version: int = 0x800) -> Iterator[tuple[PMASample, int]]: + use_9byte = sm_version >= 0xa04 + record_size = 9 if use_9byte else 8 + sample_cls = PMASampleBlackwell9B if use_9byte else PMASampleAmpere8B -def decode(data: bytes) -> Iterator[tuple[PMASampleAmpere8B, int]]: tpc_state: dict[int, list[int]] = collections.defaultdict(list) for pkt_idx in range(len(data) // 32): pkt = data[pkt_idx * 32:(pkt_idx + 1) * 32] - hdr = PMAHeaderAmpere8B.from_raw(int.from_bytes(pkt[4:8], 'little')) + hdr = PMAHeader.from_raw(int.from_bytes(pkt[4:8], 'little')) if hdr.dropped: tpc_state[hdr.tpc_id].clear() for i in range(hdr.num_bytes): - tpc_state[hdr.tpc_id].append(pkt[LOOKUP_8B[i]]) + tpc_state[hdr.tpc_id].append(pkt[LOOKUP_28B[i]]) - while len(tpc_state[hdr.tpc_id]) >= 8: - yield PMASampleAmpere8B.from_raw(int.from_bytes(bytes(tpc_state[hdr.tpc_id][:8]), 'little')), hdr.tpc_id - del tpc_state[hdr.tpc_id][:8] + while len(tpc_state[hdr.tpc_id]) >= record_size: + yield sample_cls.from_raw(int.from_bytes(bytes(tpc_state[hdr.tpc_id][:record_size]), 'little')), hdr.tpc_id + del tpc_state[hdr.tpc_id][:record_size] # ═══════════════════════════════════════════════════════════════════════════════ # CLI @@ -90,11 +125,11 @@ STALL_COLORS = { StallReason.PIPE_BUSY: "yellow", StallReason.MEMORY_THROTTLE: "RED", StallReason.OTHER: "white", } -def decode_tpc_id(tpc_id: int) -> tuple[int, int, int]: +def decode_tpc_id(tpc_id:int) -> tuple[int, int, int]: # NOTE: valid only for ops_nv, cuda encoding is different return (tpc_id >> 5, (tpc_id >> 1) & 0xf, tpc_id & 1) -def print_samples(samples: list[tuple[PMASampleAmpere8B, int]]) -> None: +def print_samples(samples:list[tuple[PMASample, int]]) -> None: if not samples: return base_pc = min(s.pc_offset for s, _ in samples) for s, tpc_id in samples: @@ -102,40 +137,58 @@ def print_samples(samples: list[tuple[PMASampleAmpere8B, int]]) -> None: stall_str = colored(f"{s.stall_reason.name:17}", STALL_COLORS.get(s.stall_reason, "white")) print(f"pc=0x{s.pc_offset - base_pc:06x} {stall_str} ev={s.stall_key:2d} active={s.active} wave={s.wave_id:2d} gpc={gpc} tpc={tpc} sm={sm}") -def print_packets(data: bytes) -> None: +def print_packets(data:bytes, sm_version:int=0x800) -> None: + record_size = 9 if sm_version >= 0x890 else 8 + tpc_state: dict[int, list[int]] = collections.defaultdict(list) for i in range(len(data) // 32): pkt = data[i * 32:(i + 1) * 32] - hdr = PMAHeaderAmpere8B.from_raw(int.from_bytes(pkt[4:8], 'little')) - print(f"Pkt {i:3d}: tpc={hdr.tpc_id} bytes={hdr.num_bytes} drop={hdr.dropped} | {pkt.hex()}") + hdr = PMAHeader.from_raw(int.from_bytes(pkt[4:8], 'little')) + if hdr.dropped: tpc_state[hdr.tpc_id].clear() + for j in range(hdr.num_bytes): tpc_state[hdr.tpc_id].append(pkt[LOOKUP_28B[j]]) + # Show complete records extracted from this packet + records = [] + while len(tpc_state[hdr.tpc_id]) >= record_size: + records.append(bytes(tpc_state[hdr.tpc_id][:record_size]).hex()) + del tpc_state[hdr.tpc_id][:record_size] + leftover = len(tpc_state[hdr.tpc_id]) + print(f"Pkt {i:3d}: tpc={hdr.tpc_id:4d} n={hdr.num_bytes:2d} drop={hdr.dropped} left={leftover} | {' '.join(records)}") -def print_aggregated(samples: list[tuple[PMASampleAmpere8B, int]]) -> None: +def print_aggregated(samples:list[tuple[PMASample, int]]) -> None: if not samples: return base_pc = min(s.pc_offset for s, _ in samples) - counter: collections.Counter[tuple[int, int]] = collections.Counter((s.pc_offset, s.stall_key) for s, _ in samples) + counter: collections.Counter[tuple[int, StallReason]] = collections.Counter((s.pc_offset, s.stall_reason) for s, _ in samples) print(f"\nAggregated samples (base_pc=0x{base_pc:x}):") - for (pc, key), cnt in sorted(counter.items()): - reason = STALL_KEY_MAP.get(key, StallReason.OTHER) + for (pc, reason), cnt in sorted(counter.items()): stall_str = colored(f"{reason.name:17}", STALL_COLORS.get(reason, "white")) - print(f" pc=0x{pc - base_pc:06x} {stall_str} ev={key:2d} samples={cnt:4d}") + print(f" pc=0x{pc - base_pc:06x} {stall_str} samples={cnt:4d}") if __name__ == "__main__": import sys, pickle if len(sys.argv) < 2: - print(__doc__) + print("Usage: python decode.py [--raw] [--sm=0xNNN]") sys.exit(1) + # Parse --sm=0xNNN argument + sm_version = 0x800 # default to Ampere + for arg in sys.argv: + if arg.startswith("--sm="): + sm_version = int(arg[5:], 0) + with open(sys.argv[1], "rb") as f: data = pickle.load(f) if isinstance(data, dict): dumps = list(enumerate(data["pma_raw_dumps"])) else: dumps = [(i, e.blob) for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent")] + record_size = 9 if sm_version >= 0x890 else 8 + print(f"SM version: 0x{sm_version:x}, using {record_size}-byte records") + for dump_idx, raw in dumps: print(f"\n{'='*60}\nDump {dump_idx} ({len(raw)} bytes, {len(raw)//32} packets)\n{'='*60}") - if "--raw" in sys.argv: print_packets(raw) + if "--raw" in sys.argv: print_packets(raw, sm_version) else: - samples = list(decode(raw)) + samples = list(decode(raw, sm_version)) print(f"\nDecoded {len(samples)} samples:") print_samples(samples) print_aggregated(samples) diff --git a/extra/nv_pma/test/test_nvprof.py b/extra/nv_pma/test/test_nvprof.py index 9ebc49e8bd..6ad4b9bb91 100644 --- a/extra/nv_pma/test/test_nvprof.py +++ b/extra/nv_pma/test/test_nvprof.py @@ -6,13 +6,17 @@ from extra.nv_pma.decode import decode from tinygrad.helpers import DEBUG EXAMPLES_DIR = Path(__file__).parent.parent / "examples" +EXAMPLES_5090_DIR = Path(__file__).parent.parent / "examples_5090" -def decode_and_aggregate(raw_dumps: list[bytes]) -> Counter[tuple[int, int]]: - """Decode all PMA buffers and aggregate by (relative_pc, stall_reason).""" - all_samples = [s for raw in raw_dumps for s, _ in decode(raw)] - if not all_samples: return Counter() - base_pc = min(s.pc_offset for s in all_samples) - return Counter((s.pc_offset - base_pc, int(s.stall_reason)) for s in all_samples) +def decode_and_aggregate(raw_dumps: list[bytes], sm_version: int = 0x800) -> Counter[tuple[int, int]]: + """Decode all PMA buffers and aggregate by (relative_pc, stall_reason). Each dump is normalized separately.""" + result: Counter[tuple[int, int]] = Counter() + for raw in raw_dumps: + samples = [s for s, _ in decode(raw, sm_version)] + if not samples: continue + base_pc = min(s.pc_offset for s in samples) + result += Counter((s.pc_offset - base_pc, int(s.stall_reason)) for s in samples) + return result def cupti_to_counter(cupti_records: list[dict]) -> Counter[tuple[int, int]]: """Convert CUPTI records to Counter[(pcOffset, stallReason)].""" @@ -22,8 +26,8 @@ def cupti_to_counter(cupti_records: list[dict]) -> Counter[tuple[int, int]]: return counter class TestNVProf(unittest.TestCase): - def _test_example(self, name: str): - pkl_file = EXAMPLES_DIR / f"{name}.pkl" + def _test_example(self, name: str, sm_version: int = 0x800, examples_dir: Path = EXAMPLES_DIR): + pkl_file = examples_dir / f"{name}.pkl" if not pkl_file.exists(): self.skipTest(f"Example data not found: {pkl_file}. Run collect.py first.") @@ -31,7 +35,7 @@ class TestNVProf(unittest.TestCase): data = pickle.load(f) self.assertEqual(data["test_name"], name) - pma_agg = decode_and_aggregate(data["pma_raw_dumps"]) + pma_agg = decode_and_aggregate(data["pma_raw_dumps"], sm_version) cupti_agg = cupti_to_counter(data["cupti_pc_samples"]) if DEBUG >= 2: @@ -45,6 +49,7 @@ class TestNVProf(unittest.TestCase): self.assertEqual(pma_agg, cupti_agg, f"PMA: {dict(pma_agg)}\nCUPTI: {dict(cupti_agg)}") + # Ampere tests (8-byte format) def test_decode_test_plus(self): self._test_example("test_plus") def test_decode_test_reduce_sum(self): self._test_example("test_reduce_sum") def test_decode_test_broadcast(self): self._test_example("test_broadcast") @@ -54,5 +59,18 @@ class TestNVProf(unittest.TestCase): def test_decode_test_conv2d(self): self._test_example("test_conv2d") def test_decode_test_large_matmul(self): self._test_example("test_large_matmul") + # Blackwell/5090 tests (9-byte format) + def test_5090_test_plus(self): self._test_example("test_plus", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_plus_big(self): self._test_example("test_plus_big", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_broadcast(self): self._test_example("test_broadcast", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_matmul(self): self._test_example("test_matmul", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_large_matmul(self): self._test_example("test_large_matmul", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_reduce_sum(self): self._test_example("test_reduce_sum", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_reduce_max(self): self._test_example("test_reduce_max", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_elementwise_chain(self): self._test_example("test_elementwise_chain", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_conv2d(self): self._test_example("test_conv2d", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_exp(self): self._test_example("test_exp", 0xa04, EXAMPLES_5090_DIR) + def test_5090_test_softmax(self): self._test_example("test_softmax", 0xa04, EXAMPLES_5090_DIR) + if __name__ == "__main__": unittest.main()