mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
display viz rewrites with tabbing if they are subrewrites (#10097)
* display viz rewrites with tabbing if they are subrewrites * update viz api
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat, Ops
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
@@ -14,6 +14,10 @@ inner_rewrite = TrackedPatternMatcher([
|
||||
(UPat.cvar("x"), lambda x: None if x.dtype == dtypes.float32 else UOp.const(dtypes.float32, x.arg)),
|
||||
])
|
||||
|
||||
l2 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=2, name="x"), lambda x: x.replace(arg=3))])
|
||||
l1 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=1, name="x"), lambda x: graph_rewrite(x.replace(arg=2), l2))])
|
||||
l0 = TrackedPatternMatcher([(UPat(Ops.CUSTOM, arg=0, name="x"), lambda x: graph_rewrite(x.replace(arg=1), l1))])
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clear the global context
|
||||
@@ -170,10 +174,24 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(contexts), 1)
|
||||
tracked = contexts[0]
|
||||
self.assertEqual(len(tracked), 3)
|
||||
self.assertEqual(tracked[0].depth, 0)
|
||||
self.assertEqual(tracked[1].depth, 1)
|
||||
self.assertEqual(tracked[2].depth, 1)
|
||||
# NOTE: this is sorted by the time called, maybe it should be by depth
|
||||
self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"])
|
||||
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
|
||||
|
||||
def test_depth_level(self):
|
||||
@track_rewrites(named=True)
|
||||
def fxn(u:UOp): return graph_rewrite(u, l0)
|
||||
ret = fxn(UOp(Ops.CUSTOM, arg=0))
|
||||
assert ret is UOp(Ops.CUSTOM, arg=3)
|
||||
self.assertEqual(len(contexts), 1)
|
||||
tracked = contexts[0]
|
||||
self.assertEqual(tracked[0].depth, 0)
|
||||
self.assertEqual(tracked[1].depth, 1)
|
||||
self.assertEqual(tracked[2].depth, 2)
|
||||
|
||||
def test_shape_label(self):
|
||||
a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,))
|
||||
b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,))
|
||||
|
||||
Reference in New Issue
Block a user