diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 9d4317c7b0..41415a8caf 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -41,6 +41,7 @@ class WaveExec: wave_id:int cu:int simd:int + se:int begin_time:int end_time:int insts:list[InstExec] @@ -78,7 +79,8 @@ class _ROCParseCtx: if DEBUG >= 8: print(inst_execs[-1]) if ev.instructions_size > 0: - self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, ev.begin_time, ev.end_time, inst_execs)) + self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time, + ev.end_time, inst_execs)) def decode(profile:list[ProfileEvent]) -> _ROCParseCtx: dev_events:dict[str, ProfileDeviceEvent] = {} diff --git a/extra/sqtt/test_timing.py b/extra/sqtt/test_timing.py index af72800658..a5725643d2 100644 --- a/extra/sqtt/test_timing.py +++ b/extra/sqtt/test_timing.py @@ -9,6 +9,7 @@ import unittest import sys, contextlib from tinygrad import Tensor from tinygrad.dtype import dtypes +from tinygrad.helpers import getenv from tinygrad.renderer import ProgramSpec from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace from tinygrad.engine.realize import CompiledRunner @@ -120,5 +121,17 @@ class TestTiming(unittest.TestCase): for e in wave.insts: print(f"{e.inst} {e.dur=} {e.stall=}") + def test_wave_sched(self): + num_waves = getenv("NUM_WAVES", 16) + num_wgps = getenv("NUM_WGPS", 2) + num_vgpr = getenv("NUM_VGPR", 256) + with save_sqtt() as sqtt: + # 1 cycle decode, no stall + asm_kernel([f"v_mov_b32_e32 v{i} {i}" for i in range(num_vgpr)], l=32*num_waves, g=num_wgps).realize() + waves = list(sqtt.values())[0] + print(len(waves), "waves decoded") + for w in waves: + print(f"{w.wave_id:<2} {w.simd=} {w.cu=} {w.se=} @ clk {w.begin_time}") + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 8a943c84b3..fa6ba95683 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -232,8 +232,8 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: for i,e in enumerate(w.insts): rows.append((e.inst, e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1])) prev_instr = max(prev_instr, e.time + e.dur) - summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"CU", "value":w.cu}, - {"label":"SIMD", "value":w.simd}] + summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SIMD", "value":w.simd}, {"label":"CU", "value":w.cu}, + {"label":"SE", "value":w.se}] steps.append({"name":f"Wave {w.wave_id}", "depth":1, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters", "data":{"rows":rows, "cols":["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"], "summary":summary}}) ctxs.append({"name":"Counters", "steps":steps})