record GraphEvents in metal graph (#11145)

* record GraphEvents in metal graph

* add TestProfiler.test_graph, revert old stuff

* move profile capture to MetalGraph

* comment

* don't double record graph command buffers

* wait_check

* explicit delete
This commit is contained in:
qazal
2025-07-10 21:32:06 +03:00
committed by GitHub
parent 8ce3d5906b
commit bde80c0cdf
3 changed files with 36 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
import unittest, struct, contextlib, statistics, time
import unittest, struct, contextlib, statistics, time, gc
from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
@@ -181,5 +181,21 @@ class TestProfiler(unittest.TestCase):
# record start/end time up to exit (error or success)
self.assertGreater(range_events[0].en-range_events[0].st, range_events[1].en-range_events[1].st)
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
def test_graph(self):
from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs
device = TestProfiler.d0.device
bufs = [helper_alloc_rawbuffer(device, fill=True) for _ in range(5)]
graphs = [[helper_exec_op(device, bufs[0], [bufs[1], bufs[2]]), helper_exec_op(device, bufs[0], [bufs[3], bufs[4]]),]]
with helper_collect_profile(dev:=TestProfiler.d0) as profile:
helper_test_graphs(dev.graph, graphs, runs:=2)
# NOTE: explicitly trigger deletion of all graphs
graphs.clear()
gc.collect()
graphs = [e for e in profile if isinstance(e, ProfileGraphEvent)]
self.assertEqual(len(graphs), runs)
for ge in graphs:
self.assertEqual(len(ge.ents), len(graphs))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,8 +1,8 @@
from typing import Any, cast
import ctypes, re
import ctypes, re, decimal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, getenv, merge_dicts
from tinygrad.device import Buffer
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.uop.ops import Variable
@@ -63,6 +63,8 @@ class MetalGraph(GraphRunner):
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
# NOTE: old command buffer may not be inflight anymore
if self.command_buffer is not None and PROFILE: self.collect_timestamps()
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
for (j,i),input_idx in self.input_replace.items():
@@ -100,3 +102,15 @@ class MetalGraph(GraphRunner):
wait_check(command_buffer)
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
return None
def collect_timestamps(self):
# create a graph event and evenly space each program
st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
step = (en-st)/len(ents)
self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]
def __del__(self):
if PROFILE and self.command_buffer is not None:
wait_check(self.command_buffer)
self.collect_timestamps()

View File

@@ -83,7 +83,8 @@ class MetalDevice(Compiled):
for cbuf in self.mtl_buffers_in_flight:
wait_check(cbuf)
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None:
# NOTE: command buffers from MetalGraph are not profiled here
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None and not lb.startswith("batched"):
Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
self.mtl_buffers_in_flight.clear()