diff --git a/test/test_hcq.py b/test/test_hcq.py index a3832f0b5b..a3aea0be5d 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -300,8 +300,6 @@ class TestHCQ(unittest.TestCase): # Test profile api def test_speed_exec_time(self): - TestHCQ.d0._prof_setup() - sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t() TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \ .exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \ @@ -311,7 +309,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True) + et = float(sig_en.timestamp - sig_st.timestamp) print(f"exec kernel time: {et:.2f} us") assert 0.1 <= et <= (7000 if CI else 100) @@ -319,8 +317,6 @@ class TestHCQ(unittest.TestCase): def test_speed_copy_bandwidth(self): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") - TestHCQ.d0._prof_setup() - # THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least. 
SZ = 200_000_000 a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate() @@ -335,7 +331,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True) + et = float(sig_en.timestamp - sig_st.timestamp) et_ms = et / 1e3 gb_s = ((SZ / 1e9) / et_ms) * 1e3 @@ -348,8 +344,6 @@ class TestHCQ(unittest.TestCase): try: _ = Device[f"{Device.DEFAULT}:1"] except Exception: self.skipTest("no multidevice, test skipped") - TestHCQ.d0._prof_setup() - SZ = 200_000_000 b = Buffer(f"{Device.DEFAULT}:1", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate() a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate() @@ -364,7 +358,7 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True) + et = float(sig_en.timestamp - sig_st.timestamp) et_ms = et / 1e3 gb_s = ((SZ / 1e9) / et_ms) * 1e3 diff --git a/test/test_profiler.py b/test/test_profiler.py index f832e32514..71755d2afd 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -1,73 +1,30 @@ -import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random +import unittest, struct, contextlib, statistics from tinygrad import Device, Tensor, dtypes, TinyJit from tinygrad.helpers import CI, getenv, Context -from tinygrad.device import Buffer, BufferSpec -from tinygrad.runtime.support.hcq import ProfileLogger, HCQCompiled +from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileRangeEvent, ProfileDeviceEvent, ProfileGraphEvent +from tinygrad.runtime.support.hcq import HCQCompiled from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_runner MOCKGPU = 
getenv("MOCKGPU") @contextlib.contextmanager -def helper_collect_profile(*devs, random_setup_delay=False): - ProfileLogger.mjson, ProfileLogger.actors = [], {} +def helper_collect_profile(*devs): + Compiled.profile_events = [] - if random_setup_delay: - devs = list(devs) - for dev in devs: dev.synchronize() - random.shuffle(devs) - for dev in devs: - dev._prof_setup() - time.sleep(random.randint(1, 1000) / 1000) - else: - for dev in devs: dev._prof_setup() - - profile_dict = {} - _, tmp = tempfile.mkstemp() - with Context(PROFILE=1, PROFILEPATH=tmp): - try: yield profile_dict + profile_list = [] + with Context(PROFILE=1): + try: yield profile_list finally: - for dev in devs: - dev.synchronize() - dev._prof_finalize() - atexit.unregister(dev._prof_finalize) + for dev in devs: dev.synchronize() + for dev in devs: dev._at_profile_finalize() + for x in Compiled.profile_events: profile_list.append(x) - for k,v in json.loads(pathlib.Path(tmp).read_text()).items(): profile_dict[k] = v - pathlib.Path(tmp).unlink() - -def helper_profile_filter_node(profile, **kwargs): - assert len(profile) > 0, "Empty profile" - assert 'traceEvents' in profile, "traceEvents should present" - return [x for x in profile['traceEvents'] if all(x.get(k, None) == v for k,v in kwargs.items())] - -def helper_profile_parse_pids(profile): - pids, tids = {}, {} - procs = helper_profile_filter_node(profile, name='process_name') - for proc in procs: pids[proc['pid']] = proc['args']['name'] - threads = helper_profile_filter_node(profile, name='thread_name') - for th in threads: tids[th['tid']] = th['args']['name'] - return pids, tids - -def helper_profile_parse_deps(profile): - deps = [] - for s in helper_profile_filter_node(profile, ph="s"): - f = helper_profile_filter_node(profile, ph="f", id=s['id'])[0] - - starts, ends = [], [] - for x in helper_profile_filter_node(profile, ph="X"): - if s['pid'] == x['pid'] and s['tid'] == x['tid'] and x['ts'] <= s['ts'] <= x['ts'] + x['dur']: starts.append(x) - if 
f['pid'] == x['pid'] and f['tid'] == x['tid'] and x['ts'] <= f['ts'] <= x['ts'] + x['dur']: ends.append(x) - - assert len(starts) == 1 and len(ends) == 1, "more than one start and end possible, valid?" - deps.append((s, f, starts[0], ends[0])) - return deps - -def helper_validate_node(node, duration_s=10, ts_age_s=30, profile=None, pid_name=None, tid_name=None): - pids, tids = helper_profile_parse_pids(profile) - assert abs(node['ts'] - time.perf_counter_ns() / 1e3) < ts_age_s * 1e6, "timestimp is not in 30s range" - assert 0 < node['dur'] < duration_s * 1e6, "duration is not in 10s range" - assert pid_name is None or pids[node['pid']] == pid_name - assert tid_name is None or tids[node['tid']] == tid_name +def helper_profile_filter_device(profile, device:str): + assert any(getattr(x, "device", None) == device and isinstance(x, ProfileDeviceEvent) for x in profile), f"device {device} is not registered" + dev_events = [x for x in profile if getattr(x, "device", None) == device and isinstance(x, ProfileDeviceEvent)] + assert len(dev_events) == 1, "only one device registration event is expected" + return [x for x in profile if getattr(x, "device", None) == device], dev_events[0] @unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "HCQ device required to run") class TestProfiler(unittest.TestCase): @@ -90,8 +47,11 @@ class TestProfiler(unittest.TestCase): with helper_collect_profile(TestProfiler.d0) as profile: TestProfiler.runner([TestProfiler.b.lazydata.buffer, TestProfiler.a.lazydata.buffer], var_vals={}) - kernel_node = helper_profile_filter_node(profile, name=runner_name)[0] - helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE") + profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) + kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)] + assert len(kernel_runs) == 1, "one kernel run is expected" + assert kernel_runs[0].name == runner_name, "kernel name is not 
correct" + assert not kernel_runs[0].is_copy, "kernel should not be copy" def test_profile_copyin(self): buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() @@ -99,8 +59,10 @@ with helper_collect_profile(TestProfiler.d0) as profile: buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) - copyin_node = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0] - helper_validate_node(copyin_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA") + profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) + kernel_runs = [x for x in profile if isinstance(x, ProfileRangeEvent)] + assert len(kernel_runs) == 1, "one kernel run is expected" + assert kernel_runs[0].is_copy, "kernel should be copy" def test_profile_multiops(self): runner_name = TestProfiler.runner._prg.name @@ -111,19 +73,19 @@ TestProfiler.runner([buf1, TestProfiler.a.lazydata.buffer], var_vals={}) buf1.as_buffer() - copyin_node = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0] - helper_validate_node(copyin_node, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA") + profile, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) + evs = [x for x in profile if isinstance(x, ProfileRangeEvent)] - kernel_node = helper_profile_filter_node(profile, name=runner_name)[0] - helper_validate_node(kernel_node, profile=profile, pid_name=Device.DEFAULT, tid_name="COMPUTE") + assert len(evs) == 3, "three events are expected" + assert evs[0].is_copy, "kernel should be copy" + assert evs[1].name == runner_name, "kernel name is not correct" + assert not evs[1].is_copy, "kernel should not be copy" + assert evs[2].is_copy, "kernel should be copy" - copyout_node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> CPU")[0] - helper_validate_node(copyout_node, profile=profile, pid_name=Device.DEFAULT, 
tid_name="DMA") + for i in range(1, 3): + assert evs[i].st > evs[i-1].en, "timestamp not arranged" - assert copyin_node['ts'] + copyin_node['dur'] < kernel_node['ts'], "timestamp not aranged" - assert kernel_node['ts'] + kernel_node['dur'] < copyout_node['ts'], "timestamp not aranged" - - def test_profile_multidev_copyin(self): + def test_profile_multidev(self): d1 = Device[f"{Device.DEFAULT}:1"] buf1 = Buffer(Device.DEFAULT, 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() buf2 = Buffer(f"{Device.DEFAULT}:1", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() @@ -132,25 +94,16 @@ buf1.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) buf2.copyin(memoryview(bytearray(struct.pack("ff", 0, 1)))) - copyin_node_1 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}")[0] - helper_validate_node(copyin_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA") + profile0, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) + profile1, _ = helper_profile_filter_device(profile, d1.device) - copyin_node_2 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}:1")[0] - helper_validate_node(copyin_node_2, profile=profile, pid_name=f"{Device.DEFAULT}:1", tid_name="DMA") - - def test_profile_multidev_transfer(self): - d1 = Device[f"{Device.DEFAULT}:1"] - a = Tensor.randn(1 << 20, device=Device.DEFAULT).realize() - with helper_collect_profile(TestProfiler.d0, d1) as profile: - y = a.to(f"{Device.DEFAULT}:1") - y.realize() - - transfer_node_1 = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[0] - helper_validate_node(transfer_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA") - assert 80 < transfer_node_1['dur'] < (20000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}" + for p in [profile0, profile1]: + evs = [x for x in p if isinstance(x, ProfileRangeEvent)] + assert len(evs) 
== 1, "one kernel runs are expected" + assert evs[0].is_copy, "kernel should be copy" @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts") - def test_profile_deps(self): + def test_profile_graph(self): d1 = Device[f"{Device.DEFAULT}:1"] def f(a): @@ -163,59 +116,40 @@ class TestProfiler(unittest.TestCase): for _ in range(3): jf(a) del jf - deps = helper_profile_parse_deps(profile) - assert len(deps) == 1, "one dep is expected, one launch" + graph_evs = [x for x in profile if isinstance(x, ProfileGraphEvent)] - _, _, l, r = deps[0] - assert l['name'].find("->") == -1, "should be kernel" - assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy" + _, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) + _, _ = helper_profile_filter_device(profile, d1.device) - @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts") - def test_profile_copy_args(self): - d1 = Device[f"{Device.DEFAULT}:1"] - - def f(a): - x = (a + 1).realize() - return x, x.to(d1.device).realize() - - a = Tensor.randn(10, 10, device=TestProfiler.d0.device).realize() - with helper_collect_profile(TestProfiler.d0, d1) as profile: - jf = TinyJit(f) - for _ in range(3): - TestProfiler.d0.raw_prof_records, TestProfiler.d0.sig_prof_records = [], [] # reset to collect only graph logs - d1.raw_prof_records, d1.sig_prof_records = [], [] - jf(a) - del jf - - node = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[-1] - assert node['args']['Size'] == "400.00 B" - assert abs(float(node['args']['GB/S']) - ((10 * 10 * 4) / 1e3) / (node['dur'])) < 0.01 + assert len(graph_evs) == 1, "one graph event is expected" + assert len(graph_evs[0].ents) == 2, "two entities are expected" @unittest.skipIf(CI, "skip CI") - def test_profile_sync(self): - mv = memoryview(bytearray(struct.pack("ff", 0, 1))) - 
expected_diff = 100000 # sleep in us + def test_dev_jitter_matrix(self): + dev_cnt = 6 + devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(dev_cnt)] + for dev in devs: dev.synchronize() + for dev in devs: dev._at_profile_finalize() - devs = [Device[f"{Device.DEFAULT}:{i}"] for i in range(6)] - bufs = [Buffer(f"{Device.DEFAULT}:{i}", 2, dtypes.float, options=BufferSpec(nolru=True)).ensure_allocated() for i in range(6)] + def _sync_d2d(d1:HCQCompiled, d2:HCQCompiled): + d1.hw_compute_queue_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \ + .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1) + d2.hw_compute_queue_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \ + .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2) + d1.timeline_value += 2 + d2.timeline_value += 2 + d1.timeline_signal.wait(d1.timeline_value - 1) + d2.timeline_signal.wait(d2.timeline_value - 1) + return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp - # enqueue ops on different queues to check the timer sync - cpu_time = [] - with helper_collect_profile(*devs, random_setup_delay=True) as profile: - for i in range(6): - x = time.perf_counter_ns() - time.sleep(expected_diff / 1e6) - bufs[i].copyin(mv) - cpu_time.append(((time.perf_counter_ns() - x) / 1000) - expected_diff) - - nodes = [helper_profile_filter_node(profile, name=f"CPU -> {Device.canonicalize(f'{Device.DEFAULT}:{i}')}")[-1] for i in range(6)] - avg_diff = [] - for i in range(1, 6): - diff = nodes[i]['ts'] - nodes[i-1]['ts'] - cpu_time[i] - avg_diff.append(diff - expected_diff) - assert expected_diff * 0.998 < diff < expected_diff * 1.002, "more that 0.2% diff" - - print(f"total avg delay is {sum(avg_diff) / len(avg_diff)} us") + # then test it by timing the GPU to GPU times + jitter_matrix = [[float('nan')] * len(devs) for _ in range(len(devs))] + pairs = 
[(p1, p2) for p1 in enumerate(devs) for p2 in enumerate(devs) if p1 != p2] + for (i1, d1), (i2, d2) in pairs: + cpu_diff = d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff + jitter_matrix[i1][i2] = statistics.median(_sync_d2d(d1, d2) - _sync_d2d(d2, d1) for _ in range(20)) / 2 - cpu_diff + assert abs(jitter_matrix[i1][i2]) < 0.5, "jitter should be less than 0.5ms" + print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix])) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/test/test_viz.py b/test/test_viz.py index 821b39a35e..1a94e95cd0 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -1,9 +1,10 @@ from typing import Dict, List, Optional -import unittest +import unittest, decimal, json from tinygrad.dtype import dtypes from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys -from tinygrad.viz.serve import get_details, get_metadata, uop_to_json +from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry +from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto @track_rewrites(named=True) def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs) @@ -112,5 +113,77 @@ class TestViz(unittest.TestCase): self.assertEqual(len(ret), 1) self.assertIs(ret[0], a.sqrt().sin()) # only rewrite +class TextVizProfiler(unittest.TestCase): + def test_perfetto_node(self): + prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False), + ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))] + + j = json.loads(to_perfetto(prof)) + + # Device regs always first + 
self.assertEqual(j['traceEvents'][0]['name'], 'process_name') + self.assertEqual(j['traceEvents'][0]['ph'], 'M') + self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV') + + self.assertEqual(j['traceEvents'][1]['name'], 'thread_name') + self.assertEqual(j['traceEvents'][1]['ph'], 'M') + self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid']) + self.assertEqual(j['traceEvents'][1]['tid'], 0) + self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE') + + self.assertEqual(j['traceEvents'][2]['name'], 'thread_name') + self.assertEqual(j['traceEvents'][2]['ph'], 'M') + self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid']) + self.assertEqual(j['traceEvents'][2]['tid'], 1) + self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY') + + self.assertEqual(j['traceEvents'][3]['name'], 'E_2') + self.assertEqual(j['traceEvents'][3]['ts'], 0) + self.assertEqual(j['traceEvents'][3]['dur'], 10) + self.assertEqual(j['traceEvents'][3]['ph'], 'X') + self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid']) + self.assertEqual(j['traceEvents'][3]['tid'], 0) + + def test_perfetto_copy_node(self): + prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True), + ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))] + + j = json.loads(to_perfetto(prof)) + + self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx') + self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock + self.assertEqual(j['traceEvents'][3]['dur'], 10) + self.assertEqual(j['traceEvents'][3]['ph'], 'X') + self.assertEqual(j['traceEvents'][3]['tid'], 1) + + def test_perfetto_graph(self): + prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)), + ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50)), + 
ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False), + ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)], + deps=[[], [0]], + sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])] + + j = json.loads(to_perfetto(prof)) + + # Device regs always first + self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV') + self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE') + self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY') + self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1') + self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE') + self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY') + + self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2') + self.assertEqual(j['traceEvents'][6]['ts'], 0) + self.assertEqual(j['traceEvents'][6]['dur'], 2) + self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid']) + + self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1') + self.assertEqual(j['traceEvents'][7]['ts'], 954) + self.assertEqual(j['traceEvents'][7]['dur'], 4) + self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid']) + + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/device.py b/tinygrad/device.py index 6d68785bc6..56d51fccff 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -1,9 +1,9 @@ from __future__ import annotations from dataclasses import dataclass, replace from collections import defaultdict -from typing import Optional, Dict, Tuple, Any, Iterator -import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re -from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv +from typing import Optional, Dict, Tuple, Any, Iterator, List, Set +import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re, 
atexit, pickle, decimal +from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes from tinygrad.renderer import Renderer from tinygrad.ops import UOp, buffers @@ -13,6 +13,7 @@ from tinygrad.ops import UOp, buffers class _Device: def __init__(self) -> None: self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] + self._opened_devices:Set[str] = set() @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):]) # NOTE: you can't cache canonicalize in case Device.DEFAULT changes @@ -26,6 +27,7 @@ class _Device: ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \ if (cname.lower() == x.lower() + "device")][0](ix) if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}") + self._opened_devices.add(ix) return ret @property def default(self) -> Compiled: return self[self.DEFAULT] @@ -42,6 +44,23 @@ class _Device: except StopIteration as exc: raise RuntimeError("no usable devices") from exc Device = _Device() +# **************** Profile **************** + +class ProfileEvent: pass + +@dataclass(frozen=True) +class ProfileDeviceEvent(ProfileEvent): + device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702 + +@dataclass(frozen=True) +class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702 + +@dataclass(frozen=True) +class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702 + +@dataclass(frozen=True) +class ProfileGraphEvent(ProfileEvent): 
ents:List[ProfileGraphEntry]; deps:List[List[int]]; sigs:List[decimal.Decimal] # noqa: E702 + # **************** Buffer + Allocators **************** @@ -202,6 +221,8 @@ class Compiler: def disassemble(self, lib:bytes): pass class Compiled: + profile_events:List[ProfileEvent] = [] + def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None): self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph self.renderer = renderer or Renderer() @@ -212,6 +233,11 @@ class Compiled: This method ensures that all previously queued operations on the device have been completed before proceeding. """ # override this in your device implementation + def _at_profile_finalize(self): + """ + Called at the end of profiling to allow the device to finalize any profiling. + """ + # override this in your device implementation # TODO: move this to each Device def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool: @@ -232,3 +258,15 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool: if device == "PYTHON": return sys.version_info >= (3, 12) if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") return True + +if PROFILE: + @atexit.register + def finalize_profile(): + devs = [Device[d] for d in Device._opened_devices] + for dev in devs: dev.synchronize() + for dev in devs: dev._at_profile_finalize() + + with open(temp("profile.pkl"), "wb") as f: pickle.dump(Compiled.profile_events, f) + + from tinygrad.ops import launch_viz + launch_viz("PROFILE", temp("profile.pkl")) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 4e9a273979..03f7c653c7 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -97,8 +97,7 @@ class ContextVar: def __lt__(self, x): return self.value < x DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 
0), ContextVar("NOOPT", 0), ContextVar("JIT", 1) -WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1) -PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json")) +WINO, CAPTURING, TRACEMETA, PROFILE = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1), ContextVar("PROFILE", 0) USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1) FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ec3157ecce..13f3c0531a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -891,9 +891,7 @@ if TRACK_MATCH_STATS: with open(fn:=temp("rewrites.pkl"), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") pickle.dump((tracked_keys, tracked_ctxs), f) - if getenv("VIZ"): - os.environ["VIZ"] = "0" - os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py"), temp("rewrites.pkl")]) + launch_viz("VIZ", temp("rewrites.pkl")) if getenv("PRINT_MATCH_STATS", 1): ret = [0,0,0.0,0.0] for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]): @@ -902,6 +900,14 @@ if TRACK_MATCH_STATS: ret = [x+y for x,y in zip(ret, v)] print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL") +def launch_viz(env_str:str, data:str): + os.environ[env_str] = "0" + os.environ[f"{env_str}_DATA"] = data + if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")): + args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else [] + args += ['--profile', 
getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else [] + os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args) + # *** simple graph rewrite engine *** class RewriteContext: diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 44d64f811b..4ed747112e 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -1,8 +1,8 @@ import collections, time from typing import List, Any, Dict, cast, Optional, Tuple, Set -from tinygrad.helpers import round_up, PROFILE, memsize_to_str +from tinygrad.helpers import round_up, PROFILE from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator -from tinygrad.device import Buffer, BufferSpec, Compiled, Device +from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent from tinygrad.dtype import dtypes from tinygrad.ops import UOp, Variable from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner @@ -51,8 +51,11 @@ class HCQGraph(MultiGraphRunner): self.kickoff_value: int = 0 self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32) + # When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1. + # TODO: This logic might allocate a few extra signals... 
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else [] - self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = [] + self.prog_graph_deps: List[List[int]] = [] + self.prof_graph_entries: List[ProfileGraphEntry] = [] last_j: Dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None) queue_access: Dict[HWQueue, Dict[HWQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None)) @@ -102,18 +105,20 @@ class HCQGraph(MultiGraphRunner): # Collect profile information if profiling is enabled. if PROFILE: + # When execution are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command. + sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2 + + # Description based on the command. prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore - sig_st, sig_en = (j * 2, True), (j * 2 + 1, True) - if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False) - - if is_exec_prg: prof_args = None - else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore - - self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args)) + self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg)) + self.prog_graph_deps.append([d - 1 for _, d in rdeps]) last_j[enqueue_queue] = j + # Check which signals are used in the profile graph. + self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))] + # Build hardware queues. 
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices} @@ -132,7 +137,7 @@ class HCQGraph(MultiGraphRunner): for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val) # Encode waits and start profile timestamp (if needed). - if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]]) + if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2]) # Encode main commands based on ji type. if isinstance(ji.prg, CompiledRunner): @@ -145,7 +150,7 @@ class HCQGraph(MultiGraphRunner): self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device])) # Encode finish profile timestamp (if needed). - if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]]) + if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1]) if signal_val is not None: enqueue_queue.signal(signal, signal_val) @@ -189,14 +194,8 @@ class HCQGraph(MultiGraphRunner): return None def collect_timestamps(self): - timestamps = [s.timestamp for s in self.prof_signals] - - for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records: - dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)] - - for x in deps: - (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x] - dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)] + # NOTE: Append to any device is fine... 
+ self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])] def __del__(self): for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1]) diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index a267714dad..2d6448f51d 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import List, Optional, Dict, Tuple, cast, Type, Union, TypeVar, Generic, Any -import contextlib, decimal, statistics, random, json, atexit, time, ctypes, array -from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv, round_up +from typing import List, Optional, Dict, Tuple, cast, Type, TypeVar, Generic, Any +import contextlib, decimal, statistics, time, ctypes, array +from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up from tinygrad.renderer import Renderer -from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator +from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent from tinygrad.ops import sym_infer, sint, Variable # **************** for HCQ Compatible Devices **************** @@ -294,51 +294,11 @@ class HCQProgram(Generic[DeviceType]): if wait: self.dev.synchronize() return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None -class ProfileLogger: - writers: int = 0 - mjson: List[Dict] = [] - actors: Dict[Union[str, Tuple[str, str]], int] = {} - - def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1 - - def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)] - - def _ensure_actor(self, actor_name, subactor_name): - if actor_name not in self.actors: - self.actors[actor_name] = (pid:=len(self.actors)) - 
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}}) - - if (subactor_key:=(actor_name,subactor_name)) not in self.actors: - self.actors[subactor_key] = (tid:=len(self.actors)) - self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}}) - - return self.actors[actor_name], self.actors.get(subactor_key, -1) - - def __del__(self): - # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview - for name, st, et, actor_name, subactor_name, args in self.events: - pid, tid = self._ensure_actor(actor_name,subactor_name) - args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None - self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args}) - - for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps: - dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name) - pid, tid = self._ensure_actor(actor_name,subactor_name) - self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"}) - self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"}) - - ProfileLogger.writers -= 1 - if ProfileLogger.writers == 0 and len(self.mjson) > 0: - with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson})) - print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.") - class HCQCompiled(Compiled, Generic[SignalType]): """ A base class for devices compatible with the HCQ (Hardware Command Queue) API. 
""" devices: List[HCQCompiled] = [] - gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan') - gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan') def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType], comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]): @@ -350,7 +310,6 @@ class HCQCompiled(Compiled, Generic[SignalType]): self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = [] self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = [] self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = [] - if PROFILE: self._prof_setup() from tinygrad.runtime.graph.hcq import HCQGraph super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph) @@ -367,13 +326,11 @@ class HCQCompiled(Compiled, Generic[SignalType]): if self.timeline_value > (1 << 31): self._wrap_timeline_signal() if PROFILE: - self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records] + Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records] self.sig_prof_records = [] - def _ensure_shared_time_base(self): - if not self.gpu2cpu_compute_time_diff.is_nan(): return - - def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWQueue]): + def _at_profile_finalize(self): + def _sync(d:HCQCompiled, q_t:Type[HWQueue]): q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d) d.timeline_value += 1 st = time.perf_counter_ns() @@ -381,65 +338,10 @@ class HCQCompiled(Compiled, Generic[SignalType]): et = time.perf_counter_ns() return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp - # randomly sample the timing from GPU to CPU - choices: List = [(d, d.hw_compute_queue_t, []) 
for d in self.devices] - choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None] - for _ in range(100*len(self.devices)): - d,q,l = random.choice(choices) - l.append(_sync_cpu_queue(d,q)) - for d,q,l in choices: - if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l) - if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l) - - def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWQueue], q2_t:Type[HWQueue]): - q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \ - .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1) - q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \ - .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2) - d1.timeline_value += 2 - d2.timeline_value += 2 - d1.timeline_signal.wait(d1.timeline_value - 1) - d2.timeline_signal.wait(d2.timeline_value - 1) - return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp - - # then test it by timing the GPU to GPU times - jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))] - for i1, d1 in enumerate(self.devices): - for i2, d2 in enumerate(self.devices): - if d1 == d2: continue - d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \ - _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2 - jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff) - print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix])) - - def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float: - """ - Translates local gpu time (timestamp) into global cpu time. 
- """ - self._ensure_shared_time_base() - return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff)) - - def _prof_setup(self): - if hasattr(self, 'profile_logger'): return - atexit.register(self._prof_finalize) - self.profile_logger = ProfileLogger() - - def _prof_finalize(self): - qname = ["COMPUTE", "DMA"] - - # Sync to be sure all events on the device are recorded. - self.synchronize() - - for st, en, name, is_cp, args in self.raw_prof_records: - self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.device, qname[is_cp], args)] - for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records: - # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint. - a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy) - self.profile_logger.deps += [(a_tm, b_tm, a_dev.device, qname[a_is_copy], b_dev.device, qname[b_is_copy])] - self.raw_prof_records, self.dep_prof_records = [], [] - - # Remove the logger, this flushes all data written by the device. 
- del self.profile_logger + gpu2cpu_compute_time_diff = statistics.median([_sync(self, self.hw_compute_queue_t) for _ in range(40)]) + if self.hw_copy_queue_t is None: gpu2cpu_copy_time_diff = decimal.Decimal(0) + else: gpu2cpu_copy_time_diff = statistics.median([_sync(self, self.hw_copy_queue_t) for _ in range(40)]) + Compiled.profile_events += [ProfileDeviceEvent(self.device, gpu2cpu_compute_time_diff, gpu2cpu_copy_time_diff)] def _wrap_timeline_signal(self): self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1 diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 7c63ceb533..f42ac724e8 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -130,19 +130,46 @@ #metadata-resize-handle { left: 0; } - .collapse-btn { + .floating-container { position: fixed; top: 10px; left: 20px; + z-index: 4; + display: flex; + flex-direction: row; + gap: 8px; + } + .nav-btn { background-color: #1a1b26; border: 1px solid #4a4b56; color: #f0f0f5; - width: 32px; height: 32px; - padding: 6px; border-radius: 8px; cursor: pointer; - z-index: 4; + text-decoration: none; + display: flex; + align-items: center; + padding: 0 6px; + font-weight: bold; + } + .collapse-btn { + width: 32px; + padding: 6px; + } + .btn { + height: 32px; + background-color: #1a1b26; + border: 1px solid #4a4b56; + color: #f0f0f5; + border-radius: 8px; + cursor: pointer; + transition-duration: .5s; + } + .btn:hover { + background-color: #2a2b36; + border-color: #5a5b66; + color: #ffffff; + transform: translateY(-1px); } .collapsed .kernel-list, .collapsed .metadata { width: 0; @@ -170,9 +197,12 @@
- +
+ + Profiler +
diff --git a/tinygrad/viz/perfetto.html b/tinygrad/viz/perfetto.html new file mode 100644 index 0000000000..a2d55c2945 --- /dev/null +++ b/tinygrad/viz/perfetto.html @@ -0,0 +1,178 @@ + + + + + + +
+
+ UOps +
+ +
+
+
Loading trace data...
+
+ + + + + + \ No newline at end of file diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index a678637bcc..62c64d38c1 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket +import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from dataclasses import asdict, dataclass @@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple, Optional from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp from tinygrad.codegen.kernel import Kernel +from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", @@ -49,7 +50,7 @@ def pcall(fxn:Callable[..., str], *args, **kwargs) -> str: def get_metadata(keys:List[Any], contexts:List[List[TrackedGraphRewrite]]) -> List[List[Tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]: kernels: Dict[str, List[Tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {} - for k,ctxs in zip(keys, contexts): + for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"): name = to_function_name(k.name) if isinstance(k, Kernel) else str(k) for ctx in ctxs: if pickle.loads(ctx.sink).op is Ops.CONST: continue @@ -99,6 +100,35 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) - g.graphs.append(sink:=new_sink) return g +# Profiler API +devices:Dict[str, Tuple[decimal.Decimal, decimal.Decimal, int]] = 
{} +def prep_ts(device:str, ts:decimal.Decimal, is_copy): return int(decimal.Decimal(ts) + devices[device][is_copy]) +def dev_to_pid(device:str, is_copy=False): return {"pid": devices[device][2], "tid": int(is_copy)} +def dev_ev_to_perfetto_json(ev:ProfileDeviceEvent): + devices[ev.device] = (ev.comp_tdiff, ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff, len(devices)) + return [{"name": "process_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "args": {"name": ev.device}}, + {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 0, "args": {"name": "COMPUTE"}}, + {"name": "thread_name", "ph": "M", "pid": dev_to_pid(ev.device)['pid'], "tid": 1, "args": {"name": "COPY"}}] +def range_ev_to_perfetto_json(ev:ProfileRangeEvent): + return [{"name": ev.name, "ph": "X", "ts": prep_ts(ev.device, ev.st, ev.is_copy), "dur": float(ev.en-ev.st), **dev_to_pid(ev.device, ev.is_copy)}] +def graph_ev_to_perfetto_json(ev:ProfileGraphEvent, reccnt): + ret = [] + for i,e in enumerate(ev.ents): + st, en = ev.sigs[e.st_id], ev.sigs[e.en_id] + ret += [{"name": e.name, "ph": "X", "ts": prep_ts(e.device, st, e.is_copy), "dur": float(en-st), **dev_to_pid(e.device, e.is_copy)}] + for dep in ev.deps[i]: + d = ev.ents[dep] + ret += [{"ph": "s", **dev_to_pid(d.device, d.is_copy), "id": reccnt+len(ret), "ts": prep_ts(d.device, ev.sigs[d.en_id], d.is_copy), "bp": "e"}] + ret += [{"ph": "f", **dev_to_pid(e.device, e.is_copy), "id": reccnt+len(ret)-1, "ts": prep_ts(e.device, st, e.is_copy), "bp": "e"}] + return ret +def to_perfetto(profile:List[ProfileEvent]): + # Start json with devices. 
+ prof_json = [x for ev in profile if isinstance(ev, ProfileDeviceEvent) for x in dev_ev_to_perfetto_json(ev)] + for ev in tqdm(profile, desc="preparing profile"): + if isinstance(ev, ProfileRangeEvent): prof_json += range_ev_to_perfetto_json(ev) + elif isinstance(ev, ProfileGraphEvent): prof_json += graph_ev_to_perfetto_json(ev, reccnt=len(prof_json)) + return json.dumps({"traceEvents": prof_json}).encode() if len(prof_json) > 0 else None + # ** HTTP server class Handler(BaseHTTPRequestHandler): @@ -107,6 +137,8 @@ class Handler(BaseHTTPRequestHandler): if (url:=urlparse(self.path)).path == "/": with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read() + elif (url:=urlparse(self.path)).path == "/profiler": + with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read() elif self.path.startswith("/assets/") and '/..' not in self.path: try: with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read() @@ -120,6 +152,7 @@ class Handler(BaseHTTPRequestHandler): jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]} else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels] ret, content_type = json.dumps(jret).encode(), "application/json" + elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json" else: status_code = 404 # send response @@ -139,7 +172,16 @@ def reloader(): os.execv(sys.executable, [sys.executable] + sys.argv) time.sleep(0.1) +def load_pickle(path:str): + if path is None or not os.path.exists(path): return None + with open(path, "rb") as f: return pickle.load(f) + if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--kernels', type=str, help='Path to kernels', default=None) + parser.add_argument('--profile', type=str, help='Path profile', default=None) + args = parser.parse_args() + 
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0: raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.") @@ -147,19 +189,23 @@ if __name__ == "__main__": multiprocessing.current_process().name = "VizProcess" # disallow opening of devices st = time.perf_counter() print("*** viz is starting") - with open(sys.argv[1], "rb") as f: contexts: Tuple[List[Any], List[List[TrackedGraphRewrite]]] = pickle.load(f) - print("*** unpickled saved rewrites") - kernels = get_metadata(*contexts) + + contexts, profile = load_pickle(args.kernels), load_pickle(args.profile) + + kernels = get_metadata(*contexts) if contexts is not None else [] + if getenv("FUZZ_VIZ"): ret = [get_details(*args) for v in tqdm(kernels) for args in v] print(f"fuzzed {len(ret)} rewrite details") - print("*** loaded kernels") + + perfetto_profile = to_perfetto(profile) if profile is not None else None + server = HTTPServer(('', PORT), Handler) reloader_thread = threading.Thread(target=reloader) reloader_thread.start() print(f"*** started viz on {HOST}:{PORT}") print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green")) - if getenv("BROWSER", 0): webbrowser.open(f"{HOST}:{PORT}") + if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}{'/profiler' if contexts is None else ''}") try: server.serve_forever() except KeyboardInterrupt: print("*** viz is shutting down...")