viz: one name arg in track_rewrites [pr] (#10873)

* viz: one name arg in track_rewrites [pr]

* other test
This commit is contained in:
qazal
2025-06-19 03:34:56 +03:00
committed by GitHub
parent 18593c9800
commit 5839542fc8
4 changed files with 19 additions and 19 deletions

View File

@@ -66,7 +66,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
@track_rewrites(named=True)
@track_rewrites(name=True)
def schedule_graph_rewrite(big_sink:UOp): return get_kernelize_map(big_sink)[big_sink]
class TestSchedule(unittest.TestCase):

View File

@@ -30,7 +30,7 @@ class TestViz(unittest.TestCase):
def test_viz_simple(self):
a = UOp.variable("a", 0, 10)
@track_rewrites(named=True)
@track_rewrites(name=True)
def test(sink): return graph_rewrite(sink, symbolic)
test(a*1)
ret = get_metadata(keys, contexts)
@@ -41,7 +41,7 @@ class TestViz(unittest.TestCase):
def test_track_two_rewrites(self):
a = UOp.variable("a", 0, 10)
@track_rewrites(named=True)
@track_rewrites(name=True)
def test(sink): return graph_rewrite(sink, symbolic)
test((a+a)*1)
ret = get_metadata(keys, contexts)
@@ -53,7 +53,7 @@ class TestViz(unittest.TestCase):
def test_track_multiple_calls_one_ctx(self):
a = UOp.variable("a", 0, 10)
@track_rewrites(named=True)
@track_rewrites(name=True)
def test(a, b):
a = graph_rewrite(a, symbolic)
b = graph_rewrite(b, symbolic)
@@ -67,7 +67,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(val[1]["match_count"], 0) # no rewrites for a*5
def test_track_rewrites(self):
@track_rewrites(named=True)
@track_rewrites(name=True)
def do_rewrite(x:UOp): return graph_rewrite(x, symbolic)
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 4)
@@ -93,7 +93,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(ret), 1)
def test_track_rewrites_name_fxn(self):
@track_rewrites(name_fxn=lambda _,ret: f"output_{ret}")
@track_rewrites(name=lambda _,ret: f"output_{ret}")
def do_rewrite(x:UOp):
x = graph_rewrite(x, symbolic)
return x.render()
@@ -109,7 +109,7 @@ class TestViz(unittest.TestCase):
@unittest.expectedFailure
def test_name_in_positional_arg(self):
@track_rewrites(named=True)
@track_rewrites(name=True)
def test(sink): return graph_rewrite(sink, symbolic, None, False, "name")
test(UOp.variable("a", 0, 1))
self.assertEqual(contexts[0].pop().name, "name")
@@ -125,7 +125,7 @@ class TestViz(unittest.TestCase):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
c = UOp.variable("c", 0, 10)
@track_rewrites(named=True)
@track_rewrites(name=True)
def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b:c}, bottom_up=True)
fxn(a+b)
#UOp.substitute(a+b, {a+b:c})
@@ -137,7 +137,7 @@ class TestViz(unittest.TestCase):
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
def test_rewrite_without_context(self):
def untracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
@track_rewrites(named=True)
@track_rewrites(name=True)
def tracked_graph_rewrite(sink): return graph_rewrite(sink, symbolic)
# test
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
@@ -149,7 +149,7 @@ class TestViz(unittest.TestCase):
def test_inner_rewrite_location(self):
# inner rewrite gets tracked in another context
def inner_rewrite(sink): return graph_rewrite(sink, symbolic)
@track_rewrites(named=True)
@track_rewrites(name=True)
def tracked_graph_rewrite(sink): return inner_rewrite(sink)
# test
add = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 1)
@@ -171,7 +171,7 @@ class TestViz(unittest.TestCase):
y2 = graph_rewrite(y, inner_rewrite, name="inner_y")
return None if (x2 is x and y2 is y) else x2+y2
outer_rewrite = TrackedPatternMatcher([(UPat.cvar("x")+UPat.cvar("y"), make_float),])
@track_rewrites(named=True)
@track_rewrites(name=True)
def rewrite(u:UOp): return graph_rewrite(u, outer_rewrite, name="outer")
a = UOp.const(dtypes.int, 1)+UOp.const(dtypes.int, 2)
rewrite(a)
@@ -186,7 +186,7 @@ class TestViz(unittest.TestCase):
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
def test_depth_level(self):
@track_rewrites(named=True)
@track_rewrites(name=True)
def fxn(u:UOp): return graph_rewrite(u, l0)
ret = fxn(UOp(Ops.CUSTOM, arg=0))
assert ret is UOp(Ops.CUSTOM, arg=3)
@@ -216,10 +216,10 @@ class TestViz(unittest.TestCase):
@unittest.skip("TODO: doesn't work")
def test_recursion_err(self):
inf = TrackedPatternMatcher([
(UPat.const(dtypes.int, 0).named("a"), lambda a: a.const_like(1)),
(UPat.const(dtypes.int, 1).named("b"), lambda b: b.const_like(0)),
(UPat.const(dtypes.int, 0).name("a"), lambda a: a.const_like(1)),
(UPat.const(dtypes.int, 1).name("b"), lambda b: b.const_like(0)),
])
@track_rewrites(named=True)
@track_rewrites(name=True)
def func(u): return graph_rewrite(u, inf)
with self.assertRaises(RecursionError): func(UOp.const(dtypes.int, 0))
_ = list(get_details(keys[0], contexts[0][0]))

View File

@@ -420,7 +420,7 @@ finalize_gbarrier = PatternMatcher([
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
@track_rewrites(name_fxn=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
@track_rewrites(name=lambda big_sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[big_sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(big_sink:UOp) -> dict[UOp, UOp]:
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(big_sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")

View File

@@ -738,15 +738,15 @@ if getenv("CAPTURE_PROCESS_REPLAY"):
def save_to_diskcache():
for k,v in replay_capture.items(): diskcache_put("process_replay", k, v, prepickled=True)
def track_rewrites(named=False, name_fxn:Callable|None=None):
def track_rewrites(name:Callable|bool|None=None):
def _decorator(func):
def __wrapper(*args, **kwargs):
if TRACK_MATCH_STATS >= 2:
if (count_names:=(named or name_fxn or not args)): _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
if (count_names:=(name or not args)): _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if count_names else args[0])
tracked_ctxs.append([])
ret = func(*args, **kwargs)
if TRACK_MATCH_STATS >= 2 and name_fxn is not None: tracked_keys[-1] = f"{name_fxn(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}"
if TRACK_MATCH_STATS >= 2 and callable(name): tracked_keys[-1] = f"{name(*args, **kwargs, ret=ret)} n{_name_cnt[func.__name__]}"
if getenv("CAPTURE_PROCESS_REPLAY"):
# find the unittest frame we're capturing in
frm = sys._getframe(1)