From 9f8afb518c375a5dc272fbb2cb0d95d41d3639b8 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:45:06 +0300 Subject: [PATCH] viz: sdma gb/s in graph (#14798) * viz: sdma gb/s in graph * f --- test/null/test_viz.py | 14 ++++++++++++++ tinygrad/device.py | 4 ++-- tinygrad/runtime/graph/hcq.py | 4 ++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/test/null/test_viz.py b/test/null/test_viz.py index fb5ca51990..a89cf56b32 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -443,6 +443,20 @@ class TestVizProfiler(BaseTestViz): self.assertEqual(graph_events[0]['st'], nv_events[0]['st']) self.assertEqual(graph_events[0]['st']+graph_events[0]['dur'], sdma_events[0]['st']+sdma_events[0]['dur']) + def test_graph_copy_bandwidth(self): + sz = 256*1024*1024 + dur = 10_000 + prof = [ProfileDeviceEvent(device='NV', tdiff=decimal.Decimal(-1000)), + ProfileDeviceEvent(device='NV:1:SDMA:0', tdiff=decimal.Decimal(-50)), + ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV:1:SDMA:0', name=TracingKey("NV -> NV:1", ret=sz), st_id=0, en_id=1)], + deps=[[]], + sigs=[decimal.Decimal(1004), decimal.Decimal(1004+dur)])] + + j = load_profile(prof) + sdma_events = j['layout']['NV:1:SDMA:0']['events'] + gbs = sz/(dur*1e-6)*1e-9 + self.assertEqual(sdma_events[0]['fmt'], f"{gbs:.0f} GB/s") + def test_block_ordering(self): prof = [ProfileDeviceEvent(device='NV', tdiff=decimal.Decimal(-1000)), ProfileDeviceEvent(device='NV:1', tdiff=decimal.Decimal(-500)), diff --git a/tinygrad/device.py b/tinygrad/device.py index 9861f886d0..88238cfaf2 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK -from tinygrad.helpers import EMULATED_DTYPES +from tinygrad.helpers import EMULATED_DTYPES, TracingKey from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -62,7 +62,7 @@ class ProfileDeviceEvent(ProfileEvent): device:str; tdiff:decimal.Decimal=decima class ProfileProgramEvent(ProfileEvent): device:str; name:str; lib:bytes|None; base:int|None; tag:int|None=None # noqa: E702 @dataclass(frozen=True) -class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int # noqa: E702 +class ProfileGraphEntry: device:str; name:str|TracingKey; st_id:int; en_id:int # noqa: E702 @dataclass(frozen=True) class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702 diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 718e55f67b..9d6619b163 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -1,6 +1,6 @@ import collections, time from typing import Any, cast -from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing +from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing, TracingKey from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent from tinygrad.dtype import dtypes @@ -129,7 +129,7 @@ class HCQGraph(MultiGraphRunner): sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2 # Description based on the command. - prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore + prof_ji_desc = ji.prg._prg.name if is_exec_prg else TracingKey(f"{ji.bufs[1].device} -> {ji.bufs[0].device}", ret=ji.bufs[0].nbytes) # type: ignore prof_name = f"{enqueue_dev.device}:SDMA:{queue_idx}" if not is_exec_prg else enqueue_dev.device self.prof_graph_entries.append(ProfileGraphEntry(prof_name, prof_ji_desc, sig_st, j * 2 + 1))