mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
record GraphEvents in metal graph (#11145)
* record GraphEvents in metal graph
* add TestProfiler.test_graph, revert old stuff
* move profile capture to MetalGraph
* comment
* don't double record graph command buffers
* wait_check
* explicit delete
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import unittest, struct, contextlib, statistics, time
|
||||
import unittest, struct, contextlib, statistics, time, gc
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context, ProfileRangeEvent, cpu_profile, cpu_events
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileDeviceEvent, ProfileGraphEvent
|
||||
@@ -181,5 +181,21 @@ class TestProfiler(unittest.TestCase):
|
||||
# record start/end time up to exit (error or success)
|
||||
self.assertGreater(range_events[0].en-range_events[0].st, range_events[1].en-range_events[1].st)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].graph is not None, "graph support required")
def test_graph(self):
  # Runs one device graph made of two exec ops `runs` times while profiling is being collected, then checks
  # that one ProfileGraphEvent was recorded per run and that every event has one entry per op in the graph.
  from test.test_graph import helper_alloc_rawbuffer, helper_exec_op, helper_test_graphs
  runs = 2
  dev = TestProfiler.d0
  device = dev.device
  bufs = [helper_alloc_rawbuffer(device, fill=True) for _ in range(5)]
  graphs = [[helper_exec_op(device, bufs[0], [bufs[1], bufs[2]]), helper_exec_op(device, bufs[0], [bufs[3], bufs[4]]),]]
  # Remember how many ops the single graph holds *before* the list is cleared below. The expected entry count
  # must come from the graph itself, not from the number of recorded events (both merely happen to be 2 here).
  kernels_per_graph = len(graphs[0])
  with helper_collect_profile(dev) as profile:
    helper_test_graphs(dev.graph, graphs, runs)
    # NOTE: explicitly trigger deletion of all graphs
    # (MetalGraph records its last command buffer from __del__; presumably this has to happen before
    # profile collection finishes -- confirm against helper_collect_profile)
    graphs.clear()
    gc.collect()
  # keep the events under their own name instead of rebinding `graphs`
  graph_events = [e for e in profile if isinstance(e, ProfileGraphEvent)]
  self.assertEqual(len(graph_events), runs)
  for ge in graph_events:
    self.assertEqual(len(ge.ents), kernels_per_graph)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, cast
|
||||
import ctypes, re
|
||||
import ctypes, re, decimal
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, getenv, merge_dicts
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
|
||||
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.uop.ops import Variable
|
||||
@@ -63,6 +63,8 @@ class MetalGraph(GraphRunner):
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
|
||||
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
# NOTE: old command buffer may not be inflight anymore
|
||||
if self.command_buffer is not None and PROFILE: self.collect_timestamps()
|
||||
|
||||
all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
@@ -100,3 +102,15 @@ class MetalGraph(GraphRunner):
|
||||
wait_check(command_buffer)
|
||||
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
||||
return None
|
||||
|
||||
def collect_timestamps(self):
  # Append one ProfileGraphEvent describing self.command_buffer to the device's profile events.
  # Only the command buffer's overall start and end times are read, so the programs in jit_cache are laid
  # out as equal-width slices of that interval rather than with individually measured times.
  cmdbuf = self.command_buffer
  # both timestamps are scaled by 1e6 (presumably seconds -> microseconds; confirm against cmdbuf_st_time)
  begin = decimal.Decimal(cmdbuf_st_time(cmdbuf)) * 1000000
  finish = decimal.Decimal(cmdbuf_en_time(cmdbuf)) * 1000000
  entries = []
  for idx, item in enumerate(self.jit_cache):
    prg_name = cast(CompiledRunner, item.prg)._prg.name
    # entry idx references timestamp slots idx (start) and idx+1 (end)
    entries.append(ProfileGraphEntry(self.device, prg_name, idx, idx+1, is_copy=False))
  slice_width = (finish-begin)/len(entries)
  # len(entries)+1 evenly spaced boundaries, from `begin` up to `finish`
  boundaries = [begin+slice_width*idx for idx in range(len(entries)+1)]
  self.dev.profile_events += [ProfileGraphEvent(entries, [], boundaries)]
|
||||
|
||||
def __del__(self):
  # When profiling is enabled, wait for the last submitted command buffer and record its graph event, so the
  # final run of the graph is not lost when the graph object is destroyed.
  # Python calls __del__ even on an instance whose __init__ raised before `command_buffer` was assigned;
  # getattr keeps the finalizer from raising AttributeError in that case.
  if PROFILE and getattr(self, "command_buffer", None) is not None:
    wait_check(self.command_buffer)
    self.collect_timestamps()
|
||||
|
||||
@@ -83,7 +83,8 @@ class MetalDevice(Compiled):
|
||||
for cbuf in self.mtl_buffers_in_flight:
|
||||
wait_check(cbuf)
|
||||
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
|
||||
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None:
|
||||
# NOTE: command buffers from MetalGraph are not profiled here
|
||||
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None and not lb.startswith("batched"):
|
||||
Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
|
||||
self.mtl_buffers_in_flight.clear()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user