display viz rewrites with tabbing if they are subrewrites (#10097)

* display viz rewrites with tabbing if they are subrewrites

* update viz api
This commit is contained in:
qazal
2025-04-29 12:57:21 +03:00
committed by GitHub
parent 73c2f6602f
commit cbf7347cd6
4 changed files with 25 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat, Ops
from tinygrad.codegen.symbolic import symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
@@ -14,6 +14,10 @@ inner_rewrite = TrackedPatternMatcher([
(UPat.cvar("x"), lambda x: None if x.dtype == dtypes.float32 else UOp.const(dtypes.float32, x.arg)),
])
l2 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=2, name="x"), lambda x: x.replace(arg=3))])
l1 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=1, name="x"), lambda x: graph_rewrite(x.replace(arg=2), l2))])
l0 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=0, name="x"), lambda x: graph_rewrite(x.replace(arg=1), l1))])
class TestViz(unittest.TestCase):
def setUp(self):
# clear the global context
@@ -170,10 +174,24 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(contexts), 1)
tracked = contexts[0]
self.assertEqual(len(tracked), 3)
self.assertEqual(tracked[0].depth, 0)
self.assertEqual(tracked[1].depth, 1)
self.assertEqual(tracked[2].depth, 1)
# NOTE: this is sorted by the time called, maybe it should be by depth
self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"])
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
def test_depth_level(self):
@track_rewrites(named=True)
def fxn(u:UOp): return graph_rewrite(u, l0)
ret = fxn(UOp(Ops.CUSTOM, arg=0))
assert ret is UOp(Ops.CUSTOM, arg=3)
self.assertEqual(len(contexts), 1)
tracked = contexts[0]
self.assertEqual(tracked[0].depth, 0)
self.assertEqual(tracked[1].depth, 1)
self.assertEqual(tracked[2].depth, 2)
def test_shape_label(self):
a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,))
b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,))