mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz profiler (#8287)
* only hcq * fix get_metadata * linter * oops * tiny * linter * time * print pm * hmm * nits
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from typing import Dict, List, Optional
|
||||
import unittest
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)
|
||||
@@ -112,5 +113,77 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(ret), 1)
|
||||
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
|
||||
|
||||
class TextVizProfiler(unittest.TestCase):
|
||||
def test_perfetto_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
|
||||
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
|
||||
# Device regs always first
|
||||
self.assertEqual(j['traceEvents'][0]['name'], 'process_name')
|
||||
self.assertEqual(j['traceEvents'][0]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
|
||||
|
||||
self.assertEqual(j['traceEvents'][1]['name'], 'thread_name')
|
||||
self.assertEqual(j['traceEvents'][1]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][1]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][1]['tid'], 0)
|
||||
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
|
||||
|
||||
self.assertEqual(j['traceEvents'][2]['name'], 'thread_name')
|
||||
self.assertEqual(j['traceEvents'][2]['ph'], 'M')
|
||||
self.assertEqual(j['traceEvents'][2]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][2]['tid'], 1)
|
||||
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
|
||||
|
||||
self.assertEqual(j['traceEvents'][3]['name'], 'E_2')
|
||||
self.assertEqual(j['traceEvents'][3]['ts'], 0)
|
||||
self.assertEqual(j['traceEvents'][3]['dur'], 10)
|
||||
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
|
||||
self.assertEqual(j['traceEvents'][3]['pid'], j['traceEvents'][0]['pid'])
|
||||
self.assertEqual(j['traceEvents'][3]['tid'], 0)
|
||||
|
||||
def test_perfetto_copy_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True),
|
||||
ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
|
||||
self.assertEqual(j['traceEvents'][3]['name'], 'COPYxx')
|
||||
self.assertEqual(j['traceEvents'][3]['ts'], 900) # diff clock
|
||||
self.assertEqual(j['traceEvents'][3]['dur'], 10)
|
||||
self.assertEqual(j['traceEvents'][3]['ph'], 'X')
|
||||
self.assertEqual(j['traceEvents'][3]['tid'], 1)
|
||||
|
||||
def test_perfetto_graph(self):
|
||||
prof = [ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100)),
|
||||
ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50)),
|
||||
ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
|
||||
ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)],
|
||||
deps=[[], [0]],
|
||||
sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
|
||||
|
||||
j = json.loads(to_perfetto(prof))
|
||||
|
||||
# Device regs always first
|
||||
self.assertEqual(j['traceEvents'][0]['args']['name'], 'NV')
|
||||
self.assertEqual(j['traceEvents'][1]['args']['name'], 'COMPUTE')
|
||||
self.assertEqual(j['traceEvents'][2]['args']['name'], 'COPY')
|
||||
self.assertEqual(j['traceEvents'][3]['args']['name'], 'NV:1')
|
||||
self.assertEqual(j['traceEvents'][4]['args']['name'], 'COMPUTE')
|
||||
self.assertEqual(j['traceEvents'][5]['args']['name'], 'COPY')
|
||||
|
||||
self.assertEqual(j['traceEvents'][6]['name'], 'E_25_4n2')
|
||||
self.assertEqual(j['traceEvents'][6]['ts'], 0)
|
||||
self.assertEqual(j['traceEvents'][6]['dur'], 2)
|
||||
self.assertEqual(j['traceEvents'][6]['pid'], j['traceEvents'][0]['pid'])
|
||||
|
||||
self.assertEqual(j['traceEvents'][7]['name'], 'NV -> NV:1')
|
||||
self.assertEqual(j['traceEvents'][7]['ts'], 954)
|
||||
self.assertEqual(j['traceEvents'][7]['dur'], 4)
|
||||
self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user